1
mirror of https://github.com/home-assistant/core synced 2024-08-06 09:34:49 +02:00

Make async_track_template_result track multiple templates (#39371)

* Make async_track_template_result track multiple templates

Combine template entity updates to only write ha
state once per template group update

* Make async_track_template_result use dataclasses for input/output

* black versions

* naming
This commit is contained in:
J. Nick Koston 2020-08-31 19:07:40 -05:00 committed by GitHub
parent 8d68963854
commit a77e09b2c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 508 additions and 215 deletions

View File

@ -21,6 +21,7 @@ from homeassistant.exceptions import TemplateError
from homeassistant.helpers import condition from homeassistant.helpers import condition
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.event import ( from homeassistant.helpers.event import (
TrackTemplate,
async_track_state_change_event, async_track_state_change_event,
async_track_template_result, async_track_template_result,
) )
@ -187,7 +188,10 @@ class BayesianBinarySensor(BinarySensorEntity):
) )
@callback @callback
def _async_template_result_changed(event, template, last_result, result): def _async_template_result_changed(event, updates):
track_template_result = updates.pop()
template = track_template_result.template
result = track_template_result.result
entity = event and event.data.get("entity_id") entity = event and event.data.get("entity_id")
if isinstance(result, TemplateError): if isinstance(result, TemplateError):
@ -215,7 +219,9 @@ class BayesianBinarySensor(BinarySensorEntity):
for template in self.observations_by_template: for template in self.observations_by_template:
info = async_track_template_result( info = async_track_template_result(
self.hass, template, _async_template_result_changed self.hass,
[TrackTemplate(template, None)],
_async_template_result_changed,
) )
self._callbacks.append(info) self._callbacks.append(info)

View File

@ -1,7 +1,7 @@
"""TemplateEntity utility class.""" """TemplateEntity utility class."""
import logging import logging
from typing import Any, Callable, Optional, Union from typing import Any, Callable, List, Optional, Union
import voluptuous as vol import voluptuous as vol
@ -9,7 +9,12 @@ from homeassistant.core import EVENT_HOMEASSISTANT_START, CoreState, callback
from homeassistant.exceptions import TemplateError from homeassistant.exceptions import TemplateError
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
from homeassistant.helpers.event import Event, async_track_template_result from homeassistant.helpers.event import (
Event,
TrackTemplate,
TrackTemplateResult,
async_track_template_result,
)
from homeassistant.helpers.template import Template, result_as_boolean from homeassistant.helpers.template import Template, result_as_boolean
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -34,7 +39,6 @@ class _TemplateAttribute:
self.validator = validator self.validator = validator
self.on_update = on_update self.on_update = on_update
self.async_update = None self.async_update = None
self.add_complete = False
self.none_on_template_error = none_on_template_error self.none_on_template_error = none_on_template_error
@callback @callback
@ -54,21 +58,14 @@ class _TemplateAttribute:
setattr(self._entity, self._attribute, attr_result) setattr(self._entity, self._attribute, attr_result)
@callback @callback
def _write_update_if_added(self): def handle_result(
if self.add_complete:
self._entity.async_write_ha_state()
@callback
def _handle_result(
self, self,
event: Optional[Event], event: Optional[Event],
template: Template, template: Template,
last_result: Optional[str], last_result: Union[str, None, TemplateError],
result: Union[str, TemplateError], result: Union[str, TemplateError],
) -> None: ) -> None:
if event: """Handle a template result event callback."""
self._entity.async_set_context(event.context)
if isinstance(result, TemplateError): if isinstance(result, TemplateError):
_LOGGER.error( _LOGGER.error(
"TemplateError('%s') " "TemplateError('%s') "
@ -83,13 +80,10 @@ class _TemplateAttribute:
self._default_update(result) self._default_update(result)
else: else:
self.on_update(result) self.on_update(result)
self._write_update_if_added()
return return
if not self.validator: if not self.validator:
self.on_update(result) self.on_update(result)
self._write_update_if_added()
return return
try: try:
@ -107,26 +101,10 @@ class _TemplateAttribute:
ex.msg, ex.msg,
) )
self.on_update(None) self.on_update(None)
self._write_update_if_added()
return return
self.on_update(validated) self.on_update(validated)
self._write_update_if_added() return
@callback
def async_template_startup(self) -> None:
"""Call from containing entity when added to hass."""
result_info = async_track_template_result(
self._entity.hass, self.template, self._handle_result
)
self.async_update = result_info.async_refresh
@callback
def _remove_from_hass():
result_info.async_remove()
return _remove_from_hass
class TemplateEntity(Entity): class TemplateEntity(Entity):
@ -141,7 +119,8 @@ class TemplateEntity(Entity):
attribute_templates=None, attribute_templates=None,
): ):
"""Template Entity.""" """Template Entity."""
self._template_attrs = [] self._template_attrs = {}
self._async_update = None
self._attribute_templates = attribute_templates self._attribute_templates = attribute_templates
self._attributes = {} self._attributes = {}
self._availability_template = availability_template self._availability_template = availability_template
@ -233,17 +212,41 @@ class TemplateEntity(Entity):
self, attribute, template, validator, on_update, none_on_template_error self, attribute, template, validator, on_update, none_on_template_error
) )
attribute.async_setup() attribute.async_setup()
self._template_attrs.append(attribute) self._template_attrs.setdefault(template, [])
self._template_attrs[template].append(attribute)
@callback
def _handle_results(
self,
event: Optional[Event],
updates: List[TrackTemplateResult],
) -> None:
"""Call back the results to the attributes."""
if event:
self.async_set_context(event.context)
for update in updates:
for attr in self._template_attrs[update.template]:
attr.handle_result(
event, update.template, update.last_result, update.result
)
if self._async_update:
self.async_write_ha_state()
async def _async_template_startup(self, *_) -> None: async def _async_template_startup(self, *_) -> None:
# async_update will not write state # _handle_results will not write state until "_async_update" is set
# until "add_complete" is set on the attribute template_var_tups = [
for attribute in self._template_attrs: TrackTemplate(template, None) for template in self._template_attrs
self.async_on_remove(attribute.async_template_startup()) ]
await self.async_update()
for attribute in self._template_attrs: result_info = async_track_template_result(
attribute.add_complete = True self.hass, template_var_tups, self._handle_results
)
self.async_on_remove(result_info.async_remove)
result_info.async_refresh()
self.async_write_ha_state() self.async_write_ha_state()
self._async_update = result_info.async_refresh
async def async_added_to_hass(self) -> None: async def async_added_to_hass(self) -> None:
"""Run when entity about to be added to hass.""" """Run when entity about to be added to hass."""
@ -272,6 +275,4 @@ class TemplateEntity(Entity):
async def async_update(self) -> None: async def async_update(self) -> None:
"""Call for forced update.""" """Call for forced update."""
for attribute in self._template_attrs: self._async_update()
if attribute.async_update:
attribute.async_update()

View File

@ -7,7 +7,11 @@ from homeassistant import exceptions
from homeassistant.const import CONF_FOR, CONF_PLATFORM, CONF_VALUE_TEMPLATE from homeassistant.const import CONF_FOR, CONF_PLATFORM, CONF_VALUE_TEMPLATE
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.helpers import config_validation as cv, template from homeassistant.helpers import config_validation as cv, template
from homeassistant.helpers.event import async_call_later, async_track_template_result from homeassistant.helpers.event import (
TrackTemplate,
async_call_later,
async_track_template_result,
)
from homeassistant.helpers.template import result_as_boolean from homeassistant.helpers.template import result_as_boolean
# mypy: allow-untyped-defs, no-check-untyped-defs # mypy: allow-untyped-defs, no-check-untyped-defs
@ -34,9 +38,10 @@ async def async_attach_trigger(
delay_cancel = None delay_cancel = None
@callback @callback
def template_listener(event, _, last_result, result): def template_listener(event, updates):
"""Listen for state changes and calls action.""" """Listen for state changes and calls action."""
nonlocal delay_cancel nonlocal delay_cancel
result = updates.pop().result
if delay_cancel: if delay_cancel:
# pylint: disable=not-callable # pylint: disable=not-callable
@ -94,7 +99,9 @@ async def async_attach_trigger(
delay_cancel = async_call_later(hass, period.seconds, call_action) delay_cancel = async_call_later(hass, period.seconds, call_action)
info = async_track_template_result( info = async_track_template_result(
hass, value_template, template_listener, automation_info["variables"] hass,
[TrackTemplate(value_template, automation_info["variables"])],
template_listener,
) )
unsub = info.async_remove unsub = info.async_remove

View File

@ -71,6 +71,7 @@ from homeassistant.const import (
from homeassistant.core import EVENT_HOMEASSISTANT_START, callback from homeassistant.core import EVENT_HOMEASSISTANT_START, callback
from homeassistant.exceptions import TemplateError from homeassistant.exceptions import TemplateError
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.event import TrackTemplate, async_track_template_result
from homeassistant.helpers.reload import async_setup_reload_service from homeassistant.helpers.reload import async_setup_reload_service
from homeassistant.helpers.service import async_call_from_config from homeassistant.helpers.service import async_call_from_config
@ -149,8 +150,10 @@ class UniversalMediaPlayer(MediaPlayerEntity):
self.async_schedule_update_ha_state(True) self.async_schedule_update_ha_state(True)
@callback @callback
def _async_on_template_update(event, template, last_result, result): def _async_on_template_update(event, updates):
"""Update ha state when dependencies update.""" """Update ha state when dependencies update."""
result = updates.pop().result
if isinstance(result, TemplateError): if isinstance(result, TemplateError):
self._state_template_result = None self._state_template_result = None
else: else:
@ -158,8 +161,10 @@ class UniversalMediaPlayer(MediaPlayerEntity):
self.async_schedule_update_ha_state(True) self.async_schedule_update_ha_state(True)
if self._state_template is not None: if self._state_template is not None:
result = self.hass.helpers.event.async_track_template_result( result = async_track_template_result(
self._state_template, _async_on_template_update self.hass,
[TrackTemplate(self._state_template, None)],
_async_on_template_update,
) )
self.hass.bus.async_listen_once( self.hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_START, callback(lambda _: result.async_refresh()) EVENT_HOMEASSISTANT_START, callback(lambda _: result.async_refresh())

View File

@ -15,7 +15,7 @@ from homeassistant.exceptions import (
Unauthorized, Unauthorized,
) )
from homeassistant.helpers import config_validation as cv, entity from homeassistant.helpers import config_validation as cv, entity
from homeassistant.helpers.event import async_track_template_result from homeassistant.helpers.event import TrackTemplate, async_track_template_result
from homeassistant.helpers.service import async_get_all_descriptions from homeassistant.helpers.service import async_get_all_descriptions
from homeassistant.loader import IntegrationNotFound, async_get_integration from homeassistant.loader import IntegrationNotFound, async_get_integration
@ -255,19 +255,23 @@ def handle_render_template(hass, connection, msg):
variables = msg.get("variables") variables = msg.get("variables")
@callback @callback
def _template_listener(event, template, last_result, result): def _template_listener(event, updates):
track_template_result = updates.pop()
result = track_template_result.result
if isinstance(result, TemplateError): if isinstance(result, TemplateError):
_LOGGER.error( _LOGGER.error(
"TemplateError('%s') " "while processing template '%s'", "TemplateError('%s') " "while processing template '%s'",
result, result,
template, track_template_result.template,
) )
result = None result = None
connection.send_message(messages.event_message(msg["id"], {"result": result})) connection.send_message(messages.event_message(msg["id"], {"result": result}))
info = async_track_template_result(hass, template, _template_listener, variables) info = async_track_template_result(
hass, [TrackTemplate(template, variables)], _template_listener
)
connection.subscriptions[msg["id"]] = info.async_remove connection.subscriptions[msg["id"]] = info.async_remove

View File

@ -1,14 +1,27 @@
"""Helpers for listening to events.""" """Helpers for listening to events."""
import asyncio import asyncio
from dataclasses import dataclass
from datetime import datetime, timedelta from datetime import datetime, timedelta
import functools as ft import functools as ft
import logging import logging
import time import time
from typing import Any, Awaitable, Callable, Iterable, Optional, Union from typing import (
Any,
Awaitable,
Callable,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
Union,
)
import attr import attr
from homeassistant.const import ( from homeassistant.const import (
ATTR_ENTITY_ID,
ATTR_NOW, ATTR_NOW,
EVENT_CORE_CONFIG_UPDATE, EVENT_CORE_CONFIG_UPDATE,
EVENT_STATE_CHANGED, EVENT_STATE_CHANGED,
@ -48,6 +61,37 @@ TRACK_ENTITY_REGISTRY_UPDATED_LISTENER = "track_entity_registry_updated_listener
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@dataclass
class TrackTemplate:
"""Class for keeping track of a template with variables.
The template is template to calculate.
The variables are variables to pass to the template.
"""
template: Template
variables: TemplateVarsType
@dataclass
class TrackTemplateResult:
"""Class for result of template tracking.
template
The template that has changed.
last_result
The output from the template on the last successful run, or None
if no previous successful run.
result
Result from the template run. This will be a string or an
TemplateError if the template resulted in an error.
"""
template: Template
last_result: Union[str, None, TemplateError]
result: Union[str, TemplateError]
def threaded_listener_factory(async_factory: Callable[..., Any]) -> CALLBACK_TYPE: def threaded_listener_factory(async_factory: Callable[..., Any]) -> CALLBACK_TYPE:
"""Convert an async event helper to a threaded one.""" """Convert an async event helper to a threaded one."""
@ -396,13 +440,16 @@ def async_track_template(
""" """
@callback @callback
def state_changed_listener( def _template_changed_listener(
event: Event, event: Event, updates: List[TrackTemplateResult]
template: Template,
last_result: Optional[str],
result: Union[str, TemplateError],
) -> None: ) -> None:
"""Check if condition is correct and run action.""" """Check if condition is correct and run action."""
track_result = updates.pop()
template = track_result.template
last_result = track_result.last_result
result = track_result.result
if isinstance(result, TemplateError): if isinstance(result, TemplateError):
_LOGGER.error( _LOGGER.error(
"Error while processing template: %s", "Error while processing template: %s",
@ -411,7 +458,11 @@ def async_track_template(
) )
return return
if result_as_boolean(last_result) or not result_as_boolean(result): if (
not isinstance(last_result, TemplateError)
and result_as_boolean(last_result)
or not result_as_boolean(result)
):
return return
hass.async_run_job( hass.async_run_job(
@ -422,7 +473,7 @@ def async_track_template(
) )
info = async_track_template_result( info = async_track_template_result(
hass, template, state_changed_listener, variables hass, [TrackTemplate(template, variables)], _template_changed_listener
) )
return info.async_remove return info.async_remove
@ -431,76 +482,89 @@ def async_track_template(
track_template = threaded_listener_factory(async_track_template) track_template = threaded_listener_factory(async_track_template)
_UNCHANGED = object()
class _TrackTemplateResultInfo: class _TrackTemplateResultInfo:
"""Handle removal / refresh of tracker.""" """Handle removal / refresh of tracker."""
def __init__( def __init__(
self, self,
hass: HomeAssistant, hass: HomeAssistant,
template: Template, track_templates: Iterable[TrackTemplate],
action: Callable, action: Callable,
variables: Optional[TemplateVarsType],
): ):
"""Handle removal / refresh of tracker init.""" """Handle removal / refresh of tracker init."""
self.hass = hass self.hass = hass
self._template = template
self._template.hass = hass
self._action = action self._action = action
self._variables = variables
self._last_result: Optional[Union[str, TemplateError]] = None for track_template_ in track_templates:
track_template_.template.hass = hass
self._track_templates = track_templates
self._all_listener: Optional[Callable] = None self._all_listener: Optional[Callable] = None
self._domains_listener: Optional[Callable] = None self._domains_listener: Optional[Callable] = None
self._entities_listener: Optional[Callable] = None self._entities_listener: Optional[Callable] = None
self._info: Optional[RenderInfo] = None
self._last_info: Optional[RenderInfo] = None self._last_result: Dict[Template, Union[str, TemplateError]] = {}
self._last_info: Dict[Template, RenderInfo] = {}
self._info: Dict[Template, RenderInfo] = {}
self._last_domains: Set = set()
self._last_entities: Set = set()
def async_setup(self) -> None: def async_setup(self) -> None:
"""Activation of template tracking.""" """Activation of template tracking."""
self._info = self._template.async_render_to_info(self._variables) for track_template_ in self._track_templates:
if self._info.exception: template = track_template_.template
_LOGGER.error( variables = track_template_.variables
"Error while processing template: %s",
self._template.template, self._info[template] = template.async_render_to_info(variables)
exc_info=self._info.exception, if self._info[template].exception:
) _LOGGER.error(
"Error while processing template: %s",
track_template_.template,
exc_info=self._info[template].exception,
)
self._last_info = self._info.copy()
self._create_listeners() self._create_listeners()
self._last_info = self._info
@property @property
def _needs_all_listener(self) -> bool: def _needs_all_listener(self) -> bool:
assert self._info for track_template_ in self._track_templates:
template = track_template_.template
# Tracking all states # Tracking all states
if self._info.all_states: if self._info[template].all_states:
return True return True
# Previous call had an exception # Previous call had an exception
# so we do not know which states # so we do not know which states
# to track # to track
if self._info.exception: if self._info[template].exception:
return True return True
return False return False
@property
def _all_templates_are_static(self) -> bool:
for track_template_ in self._track_templates:
if not self._info[track_template_.template].is_static:
return False
return True
@callback @callback
def _create_listeners(self) -> None: def _create_listeners(self) -> None:
assert self._info if self._all_templates_are_static:
if self._info.is_static:
return return
if self._needs_all_listener: if self._needs_all_listener:
self._setup_all_listener() self._setup_all_listener()
return return
if self._info.domains: self._last_entities, self._last_domains = _entities_domains_from_info(
self._setup_domains_listener() self._info.values()
)
if self._info.entities or self._info.domains: self._setup_domains_listener(self._last_domains)
self._setup_entities_listener() self._setup_entities_listener(self._last_domains, self._last_entities)
@callback @callback
def _cancel_domains_listener(self) -> None: def _cancel_domains_listener(self) -> None:
@ -525,12 +589,11 @@ class _TrackTemplateResultInfo:
@callback @callback
def _update_listeners(self) -> None: def _update_listeners(self) -> None:
assert self._info
assert self._last_info
if self._needs_all_listener: if self._needs_all_listener:
if self._all_listener: if self._all_listener:
return return
self._last_domains = set()
self._last_entities = set()
self._cancel_domains_listener() self._cancel_domains_listener()
self._cancel_entities_listener() self._cancel_entities_listener()
self._setup_all_listener() self._setup_all_listener()
@ -540,27 +603,26 @@ class _TrackTemplateResultInfo:
if had_all_listener: if had_all_listener:
self._cancel_all_listener() self._cancel_all_listener()
domains_changed = self._info.domains != self._last_info.domains entities, domains = _entities_domains_from_info(self._info.values())
domains_changed = domains != self._last_domains
if had_all_listener or domains_changed: if had_all_listener or domains_changed:
domains_changed = True domains_changed = True
self._cancel_domains_listener() self._cancel_domains_listener()
self._setup_domains_listener() self._setup_domains_listener(domains)
if ( if had_all_listener or domains_changed or entities != self._last_entities:
had_all_listener
or domains_changed
or self._info.entities != self._last_info.entities
):
self._cancel_entities_listener() self._cancel_entities_listener()
self._setup_entities_listener() self._setup_entities_listener(domains, entities)
self._last_domains = domains
self._last_entities = entities
@callback @callback
def _setup_entities_listener(self) -> None: def _setup_entities_listener(self, domains: Set, entities: Set) -> None:
assert self._info if domains:
entities = entities.copy()
entities = set(self._info.entities) entities.update(self.hass.states.async_entity_ids(domains))
for entity_id in self.hass.states.async_entity_ids(self._info.domains):
entities.add(entity_id)
# Entities has changed to none # Entities has changed to none
if not entities: if not entities:
@ -571,15 +633,12 @@ class _TrackTemplateResultInfo:
) )
@callback @callback
def _setup_domains_listener(self) -> None: def _setup_domains_listener(self, domains: Set) -> None:
assert self._info if not domains:
# Domains has changed to none
if not self._info.domains:
return return
self._domains_listener = async_track_state_added_domain( self._domains_listener = async_track_state_added_domain(
self.hass, self._info.domains, self._refresh self.hass, domains, self._refresh
) )
@callback @callback
@ -596,40 +655,67 @@ class _TrackTemplateResultInfo:
self._cancel_entities_listener() self._cancel_entities_listener()
@callback @callback
def async_refresh(self, variables: Any = _UNCHANGED) -> None: def async_refresh(self) -> None:
"""Force recalculate the template.""" """Force recalculate the template."""
if variables is not _UNCHANGED:
self._variables = variables
self._refresh(None) self._refresh(None)
@callback @callback
def _refresh(self, event: Optional[Event]) -> None: def _refresh(self, event: Optional[Event]) -> None:
self._info = self._template.async_render_to_info(self._variables) entity_id = event and event.data.get(ATTR_ENTITY_ID)
self._update_listeners() updates = []
self._last_info = self._info info_changed = False
try: for track_template_ in self._track_templates:
result: Union[str, TemplateError] = self._info.result template = track_template_.template
except TemplateError as ex: if (
result = ex entity_id
and len(self._last_info) > 1
and not self._last_info[template].filter_lifecycle(entity_id)
):
continue
# Check to see if the result has changed self._info[template] = template.async_render_to_info(
if result == self._last_result: track_template_.variables
)
info_changed = True
try:
result: Union[str, TemplateError] = self._info[template].result
except TemplateError as ex:
result = ex
last_result = self._last_result.get(template)
# Check to see if the result has changed
if result == last_result:
continue
if isinstance(result, TemplateError) and isinstance(
last_result, TemplateError
):
continue
updates.append(TrackTemplateResult(template, last_result, result))
if info_changed:
self._update_listeners()
self._last_info = self._info.copy()
if not updates:
return return
if isinstance(result, TemplateError) and isinstance( for track_result in updates:
self._last_result, TemplateError self._last_result[track_result.template] = track_result.result
):
return
self.hass.async_run_job( self.hass.async_run_job(self._action, event, updates)
self._action, event, self._template, self._last_result, result
)
self._last_result = result
TrackTemplateResultListener = Callable[ TrackTemplateResultListener = Callable[
[Event, Template, Optional[str], Union[str, TemplateError]], None [
Event,
List[TrackTemplateResult],
],
None,
] ]
"""Type for the listener for template results. """Type for the listener for template results.
@ -638,14 +724,8 @@ TrackTemplateResultListener = Callable[
event event
Event that caused the template to change output. None if not Event that caused the template to change output. None if not
triggered by an event. triggered by an event.
template updates
The template that has changed. A list of TrackTemplateResult
last_result
The output from the template on the last successful run, or None
if no previous successful run.
result
Result from the template run. This will be a string or an
TemplateError if the template resulted in an error.
""" """
@ -653,9 +733,8 @@ TrackTemplateResultListener = Callable[
@bind_hass @bind_hass
def async_track_template_result( def async_track_template_result(
hass: HomeAssistant, hass: HomeAssistant,
template: Template, track_templates: Iterable[TrackTemplate],
action: TrackTemplateResultListener, action: TrackTemplateResultListener,
variables: Optional[TemplateVarsType] = None,
) -> _TrackTemplateResultInfo: ) -> _TrackTemplateResultInfo:
"""Add a listener that fires when a the result of a template changes. """Add a listener that fires when a the result of a template changes.
@ -675,19 +754,18 @@ def async_track_template_result(
---------- ----------
hass hass
Home assistant object. Home assistant object.
template track_templates
The template to calculate. An iterable of TrackTemplate.
action action
Callable to call with results. Callable to call with results.
variables
Variables to pass to the template.
Returns Returns
------- -------
Info object used to unregister the listener, and refresh the template. Info object used to unregister the listener, and refresh the template.
""" """
tracker = _TrackTemplateResultInfo(hass, template, action, variables) tracker = _TrackTemplateResultInfo(hass, track_templates, action)
tracker.async_setup() tracker.async_setup()
return tracker return tracker
@ -1073,3 +1151,16 @@ def process_state_match(
parameter_set = set(parameter) parameter_set = set(parameter)
return lambda state: state in parameter_set return lambda state: state in parameter_set
def _entities_domains_from_info(render_infos: Iterable[RenderInfo]) -> Tuple[Set, Set]:
"""Combine from multiple RenderInfo."""
entities = set()
domains = set()
for render_info in render_infos:
if render_info.entities:
entities.update(render_info.entities)
if render_info.domains:
domains.update(render_info.domains)
return entities, domains

View File

@ -192,7 +192,7 @@ class RenderInfo:
self.entities = frozenset(self.entities) self.entities = frozenset(self.entities)
self.domains = frozenset(self.domains) self.domains = frozenset(self.domains)
if self.all_states: if self.all_states or self.exception:
return return
if not self.domains: if not self.domains:

View File

@ -14,6 +14,8 @@ from homeassistant.core import callback
from homeassistant.exceptions import TemplateError from homeassistant.exceptions import TemplateError
from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED
from homeassistant.helpers.event import ( from homeassistant.helpers.event import (
TrackTemplate,
TrackTemplateResult,
async_call_later, async_call_later,
async_track_point_in_time, async_track_point_in_time,
async_track_point_in_utc_time, async_track_point_in_utc_time,
@ -581,22 +583,35 @@ async def test_track_template_result(hass):
"{{(states.sensor.test.state|int) + test }}", hass "{{(states.sensor.test.state|int) + test }}", hass
) )
def specific_run_callback(event, template, old_result, new_result): def specific_run_callback(event, updates):
specific_runs.append(int(new_result)) track_result = updates.pop()
specific_runs.append(int(track_result.result))
async_track_template_result(hass, template_condition, specific_run_callback)
@ha.callback
def wildcard_run_callback(event, template, old_result, new_result):
wildcard_runs.append((int(old_result or 0), int(new_result)))
async_track_template_result(hass, template_condition, wildcard_run_callback)
async def wildercard_run_callback(event, template, old_result, new_result):
wildercard_runs.append((int(old_result or 0), int(new_result)))
async_track_template_result( async_track_template_result(
hass, template_condition_var, wildercard_run_callback, {"test": 5} hass, [TrackTemplate(template_condition, None)], specific_run_callback
)
@ha.callback
def wildcard_run_callback(event, updates):
track_result = updates.pop()
wildcard_runs.append(
(int(track_result.last_result or 0), int(track_result.result))
)
async_track_template_result(
hass, [TrackTemplate(template_condition, None)], wildcard_run_callback
)
async def wildercard_run_callback(event, updates):
track_result = updates.pop()
wildercard_runs.append(
(int(track_result.last_result or 0), int(track_result.result))
)
async_track_template_result(
hass,
[TrackTemplate(template_condition_var, {"test": 5})],
wildercard_run_callback,
) )
await hass.async_block_till_done() await hass.async_block_till_done()
@ -661,13 +676,15 @@ async def test_track_template_result_complex(hass):
""" """
template_complex = Template(template_complex_str, hass) template_complex = Template(template_complex_str, hass)
def specific_run_callback(event, template, old_result, new_result): def specific_run_callback(event, updates):
specific_runs.append(new_result) specific_runs.append(updates.pop().result)
hass.states.async_set("light.one", "on") hass.states.async_set("light.one", "on")
hass.states.async_set("lock.one", "locked") hass.states.async_set("lock.one", "locked")
async_track_template_result(hass, template_complex, specific_run_callback) async_track_template_result(
hass, [TrackTemplate(template_complex, None)], specific_run_callback
)
await hass.async_block_till_done() await hass.async_block_till_done()
hass.states.async_set("sensor.domain", "light") hass.states.async_set("sensor.domain", "light")
@ -742,14 +759,16 @@ async def test_track_template_result_with_wildcard(hass):
""" """
template_complex = Template(template_complex_str, hass) template_complex = Template(template_complex_str, hass)
def specific_run_callback(event, template, old_result, new_result): def specific_run_callback(event, updates):
specific_runs.append(new_result) specific_runs.append(updates.pop().result)
hass.states.async_set("cover.office_drapes", "closed") hass.states.async_set("cover.office_drapes", "closed")
hass.states.async_set("cover.office_window", "closed") hass.states.async_set("cover.office_window", "closed")
hass.states.async_set("cover.office_skylight", "open") hass.states.async_set("cover.office_skylight", "open")
async_track_template_result(hass, template_complex, specific_run_callback) async_track_template_result(
hass, [TrackTemplate(template_complex, None)], specific_run_callback
)
await hass.async_block_till_done() await hass.async_block_till_done()
hass.states.async_set("cover.office_window", "open") hass.states.async_set("cover.office_window", "open")
@ -786,10 +805,12 @@ async def test_track_template_result_with_group(hass):
""" """
template_complex = Template(template_complex_str, hass) template_complex = Template(template_complex_str, hass)
def specific_run_callback(event, template, old_result, new_result): def specific_run_callback(event, updates):
specific_runs.append(new_result) specific_runs.append(updates.pop().result)
async_track_template_result(hass, template_complex, specific_run_callback) async_track_template_result(
hass, [TrackTemplate(template_complex, None)], specific_run_callback
)
await hass.async_block_till_done() await hass.async_block_till_done()
hass.states.async_set("sensor.power_1", 100.1) hass.states.async_set("sensor.power_1", 100.1)
@ -827,13 +848,12 @@ async def test_track_template_result_and_conditional(hass):
template = Template(template_str, hass) template = Template(template_str, hass)
def specific_run_callback(event, template, old_result, new_result): def specific_run_callback(event, updates):
import pprint specific_runs.append(updates.pop().result)
pprint.pprint([event, template, old_result, new_result]) async_track_template_result(
specific_runs.append(new_result) hass, [TrackTemplate(template, None)], specific_run_callback
)
async_track_template_result(hass, template, specific_run_callback)
await hass.async_block_till_done() await hass.async_block_till_done()
hass.states.async_set("light.b", "on") hass.states.async_set("light.b", "on")
@ -869,21 +889,26 @@ async def test_track_template_result_iterator(hass):
iterator_runs = [] iterator_runs = []
@ha.callback @ha.callback
def iterator_callback(event, template, old_result, new_result): def iterator_callback(event, updates):
iterator_runs.append(new_result) iterator_runs.append(updates.pop().result)
async_track_template_result( async_track_template_result(
hass, hass,
Template( [
""" TrackTemplate(
Template(
"""
{% for state in states.sensor %} {% for state in states.sensor %}
{% if state.state == 'on' %} {% if state.state == 'on' %}
{{ state.entity_id }}, {{ state.entity_id }},
{% endif %} {% endif %}
{% endfor %} {% endfor %}
""", """,
hass, hass,
), ),
None,
)
],
iterator_callback, iterator_callback,
) )
await hass.async_block_till_done() await hass.async_block_till_done()
@ -896,16 +921,21 @@ async def test_track_template_result_iterator(hass):
filter_runs = [] filter_runs = []
@ha.callback @ha.callback
def filter_callback(event, template, old_result, new_result): def filter_callback(event, updates):
filter_runs.append(new_result) filter_runs.append(updates.pop().result)
async_track_template_result( async_track_template_result(
hass, hass,
Template( [
"""{{ states.sensor|selectattr("state","equalto","on") TrackTemplate(
Template(
"""{{ states.sensor|selectattr("state","equalto","on")
|join(",", attribute="entity_id") }}""", |join(",", attribute="entity_id") }}""",
hass, hass,
), ),
None,
)
],
filter_callback, filter_callback,
) )
await hass.async_block_till_done() await hass.async_block_till_done()
@ -931,21 +961,42 @@ async def test_track_template_result_errors(hass, caplog):
syntax_error_runs = [] syntax_error_runs = []
not_exist_runs = [] not_exist_runs = []
def syntax_error_listener(event, template, last_result, result): @ha.callback
syntax_error_runs.append((event, template, last_result, result)) def syntax_error_listener(event, updates):
track_result = updates.pop()
syntax_error_runs.append(
(
event,
track_result.template,
track_result.last_result,
track_result.result,
)
)
async_track_template_result(hass, template_syntax_error, syntax_error_listener) async_track_template_result(
hass, [TrackTemplate(template_syntax_error, None)], syntax_error_listener
)
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(syntax_error_runs) == 0 assert len(syntax_error_runs) == 0
assert "TemplateSyntaxError" in caplog.text assert "TemplateSyntaxError" in caplog.text
@ha.callback
def not_exist_runs_error_listener(event, updates):
template_track = updates.pop()
not_exist_runs.append(
(
event,
template_track.template,
template_track.last_result,
template_track.result,
)
)
async_track_template_result( async_track_template_result(
hass, hass,
template_not_exist, [TrackTemplate(template_not_exist, None)],
lambda event, template, last_result, result: ( not_exist_runs_error_listener,
not_exist_runs.append((event, template, last_result, result))
),
) )
await hass.async_block_till_done() await hass.async_block_till_done()
@ -990,10 +1041,13 @@ async def test_track_template_result_refresh_cancel(hass):
refresh_runs = [] refresh_runs = []
def refresh_listener(event, template, last_result, result): @ha.callback
refresh_runs.append(result) def refresh_listener(event, updates):
refresh_runs.append(updates.pop().result)
info = async_track_template_result(hass, template_refresh, refresh_listener) info = async_track_template_result(
hass, [TrackTemplate(template_refresh, None)], refresh_listener
)
await hass.async_block_till_done() await hass.async_block_till_done()
hass.states.async_set("switch.test", "off") hass.states.async_set("switch.test", "off")
@ -1020,7 +1074,9 @@ async def test_track_template_result_refresh_cancel(hass):
refresh_runs = [] refresh_runs = []
info = async_track_template_result( info = async_track_template_result(
hass, template_refresh, refresh_listener, {"value": "duck"} hass,
[TrackTemplate(template_refresh, {"value": "duck"})],
refresh_listener,
) )
await hass.async_block_till_done() await hass.async_block_till_done()
info.async_refresh() info.async_refresh()
@ -1032,9 +1088,132 @@ async def test_track_template_result_refresh_cancel(hass):
await hass.async_block_till_done() await hass.async_block_till_done()
assert refresh_runs == ["duck"] assert refresh_runs == ["duck"]
info.async_refresh({"value": "dog"})
async def test_async_track_template_result_multiple_templates(hass):
"""Test tracking multiple templates."""
template_1 = Template("{{ states.switch.test.state == 'on' }}")
template_2 = Template("{{ states.switch.test.state == 'on' }}")
template_3 = Template("{{ states.switch.test.state == 'off' }}")
template_4 = Template(
"{{ states.binary_sensor | map(attribute='entity_id') | list }}"
)
refresh_runs = []
@ha.callback
def refresh_listener(event, updates):
refresh_runs.append(updates)
async_track_template_result(
hass,
[
TrackTemplate(template_1, None),
TrackTemplate(template_2, None),
TrackTemplate(template_3, None),
TrackTemplate(template_4, None),
],
refresh_listener,
)
hass.states.async_set("switch.test", "on")
await hass.async_block_till_done() await hass.async_block_till_done()
assert refresh_runs == ["duck", "dog"]
assert refresh_runs == [
[
TrackTemplateResult(template_1, None, "True"),
TrackTemplateResult(template_2, None, "True"),
TrackTemplateResult(template_3, None, "False"),
]
]
refresh_runs = []
hass.states.async_set("switch.test", "off")
await hass.async_block_till_done()
assert refresh_runs == [
[
TrackTemplateResult(template_1, "True", "False"),
TrackTemplateResult(template_2, "True", "False"),
TrackTemplateResult(template_3, "False", "True"),
]
]
refresh_runs = []
hass.states.async_set("binary_sensor.test", "off")
await hass.async_block_till_done()
assert refresh_runs == [
[TrackTemplateResult(template_4, None, "['binary_sensor.test']")]
]
async def test_async_track_template_result_multiple_templates_mixing_domain(hass):
"""Test tracking multiple templates when tracking entities and an entire domain."""
template_1 = Template("{{ states.switch.test.state == 'on' }}")
template_2 = Template("{{ states.switch.test.state == 'on' }}")
template_3 = Template("{{ states.switch.test.state == 'off' }}")
template_4 = Template("{{ states.switch | map(attribute='entity_id') | list }}")
refresh_runs = []
@ha.callback
def refresh_listener(event, updates):
refresh_runs.append(updates)
async_track_template_result(
hass,
[
TrackTemplate(template_1, None),
TrackTemplate(template_2, None),
TrackTemplate(template_3, None),
TrackTemplate(template_4, None),
],
refresh_listener,
)
hass.states.async_set("switch.test", "on")
await hass.async_block_till_done()
assert refresh_runs == [
[
TrackTemplateResult(template_1, None, "True"),
TrackTemplateResult(template_2, None, "True"),
TrackTemplateResult(template_3, None, "False"),
TrackTemplateResult(template_4, None, "['switch.test']"),
]
]
refresh_runs = []
hass.states.async_set("switch.test", "off")
await hass.async_block_till_done()
assert refresh_runs == [
[
TrackTemplateResult(template_1, "True", "False"),
TrackTemplateResult(template_2, "True", "False"),
TrackTemplateResult(template_3, "False", "True"),
]
]
refresh_runs = []
hass.states.async_set("binary_sensor.test", "off")
await hass.async_block_till_done()
assert refresh_runs == []
refresh_runs = []
hass.states.async_set("switch.new", "off")
await hass.async_block_till_done()
assert refresh_runs == [
[
TrackTemplateResult(
template_4, "['switch.test']", "['switch.new', 'switch.test']"
)
]
]
async def test_track_same_state_simple_no_trigger(hass): async def test_track_same_state_simple_no_trigger(hass):