diff --git a/homeassistant/components/bayesian/binary_sensor.py b/homeassistant/components/bayesian/binary_sensor.py index 86f11cda7e16..8d4dab622630 100644 --- a/homeassistant/components/bayesian/binary_sensor.py +++ b/homeassistant/components/bayesian/binary_sensor.py @@ -21,6 +21,7 @@ from homeassistant.exceptions import TemplateError from homeassistant.helpers import condition import homeassistant.helpers.config_validation as cv from homeassistant.helpers.event import ( + TrackTemplate, async_track_state_change_event, async_track_template_result, ) @@ -187,7 +188,10 @@ class BayesianBinarySensor(BinarySensorEntity): ) @callback - def _async_template_result_changed(event, template, last_result, result): + def _async_template_result_changed(event, updates): + track_template_result = updates.pop() + template = track_template_result.template + result = track_template_result.result entity = event and event.data.get("entity_id") if isinstance(result, TemplateError): @@ -215,7 +219,9 @@ class BayesianBinarySensor(BinarySensorEntity): for template in self.observations_by_template: info = async_track_template_result( - self.hass, template, _async_template_result_changed + self.hass, + [TrackTemplate(template, None)], + _async_template_result_changed, ) self._callbacks.append(info) diff --git a/homeassistant/components/template/template_entity.py b/homeassistant/components/template/template_entity.py index d63a2866510e..20b0caec3caa 100644 --- a/homeassistant/components/template/template_entity.py +++ b/homeassistant/components/template/template_entity.py @@ -1,7 +1,7 @@ """TemplateEntity utility class.""" import logging -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, List, Optional, Union import voluptuous as vol @@ -9,7 +9,12 @@ from homeassistant.core import EVENT_HOMEASSISTANT_START, CoreState, callback from homeassistant.exceptions import TemplateError import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entity import Entity -from homeassistant.helpers.event import Event, async_track_template_result +from homeassistant.helpers.event import ( + Event, + TrackTemplate, + TrackTemplateResult, + async_track_template_result, +) from homeassistant.helpers.template import Template, result_as_boolean _LOGGER = logging.getLogger(__name__) @@ -34,7 +39,6 @@ class _TemplateAttribute: self.validator = validator self.on_update = on_update self.async_update = None - self.add_complete = False self.none_on_template_error = none_on_template_error @callback @@ -54,21 +58,14 @@ class _TemplateAttribute: setattr(self._entity, self._attribute, attr_result) @callback - def _write_update_if_added(self): - if self.add_complete: - self._entity.async_write_ha_state() - - @callback - def _handle_result( + def handle_result( self, event: Optional[Event], template: Template, - last_result: Optional[str], + last_result: Union[str, None, TemplateError], result: Union[str, TemplateError], ) -> None: - if event: - self._entity.async_set_context(event.context) - + """Handle a template result event callback.""" if isinstance(result, TemplateError): _LOGGER.error( "TemplateError('%s') " @@ -83,13 +80,10 @@ class _TemplateAttribute: self._default_update(result) else: self.on_update(result) - self._write_update_if_added() - return if not self.validator: self.on_update(result) - self._write_update_if_added() return try: @@ -107,26 +101,10 @@ class _TemplateAttribute: ex.msg, ) self.on_update(None) - self._write_update_if_added() return self.on_update(validated) - self._write_update_if_added() - - @callback - def async_template_startup(self) -> None: - """Call from containing entity when added to hass.""" - result_info = async_track_template_result( - self._entity.hass, self.template, self._handle_result - ) - - self.async_update = result_info.async_refresh - - @callback - def _remove_from_hass(): - result_info.async_remove() - - return _remove_from_hass + return class TemplateEntity(Entity): @@ -141,7 +119,8 @@ class TemplateEntity(Entity): attribute_templates=None, ): """Template Entity.""" - self._template_attrs = [] + self._template_attrs = {} + self._async_update = None self._attribute_templates = attribute_templates self._attributes = {} self._availability_template = availability_template @@ -233,17 +212,41 @@ class TemplateEntity(Entity): self, attribute, template, validator, on_update, none_on_template_error ) attribute.async_setup() - self._template_attrs.append(attribute) + self._template_attrs.setdefault(template, []) + self._template_attrs[template].append(attribute) + + @callback + def _handle_results( + self, + event: Optional[Event], + updates: List[TrackTemplateResult], + ) -> None: + """Call back the results to the attributes.""" + if event: + self.async_set_context(event.context) + + for update in updates: + for attr in self._template_attrs[update.template]: + attr.handle_result( + event, update.template, update.last_result, update.result + ) + + if self._async_update: + self.async_write_ha_state() async def _async_template_startup(self, *_) -> None: - # async_update will not write state - # until "add_complete" is set on the attribute - for attribute in self._template_attrs: - self.async_on_remove(attribute.async_template_startup()) - await self.async_update() - for attribute in self._template_attrs: - attribute.add_complete = True + # _handle_results will not write state until "_async_update" is set + template_var_tups = [ + TrackTemplate(template, None) for template in self._template_attrs + ] + + result_info = async_track_template_result( + self.hass, template_var_tups, self._handle_results + ) + self.async_on_remove(result_info.async_remove) + result_info.async_refresh() self.async_write_ha_state() + self._async_update = result_info.async_refresh async def async_added_to_hass(self) -> None: """Run when entity about to be added to hass.""" @@ -272,6 +275,4 @@ class TemplateEntity(Entity): async def async_update(self) -> None: """Call for forced update.""" - for attribute in self._template_attrs: - if attribute.async_update: - attribute.async_update() + self._async_update() diff --git a/homeassistant/components/template/trigger.py b/homeassistant/components/template/trigger.py index 980faf4d0a8b..5dcee0a7347e 100644 --- a/homeassistant/components/template/trigger.py +++ b/homeassistant/components/template/trigger.py @@ -7,7 +7,11 @@ from homeassistant import exceptions from homeassistant.const import CONF_FOR, CONF_PLATFORM, CONF_VALUE_TEMPLATE from homeassistant.core import callback from homeassistant.helpers import config_validation as cv, template -from homeassistant.helpers.event import async_call_later, async_track_template_result +from homeassistant.helpers.event import ( + TrackTemplate, + async_call_later, + async_track_template_result, +) from homeassistant.helpers.template import result_as_boolean # mypy: allow-untyped-defs, no-check-untyped-defs @@ -34,9 +38,10 @@ async def async_attach_trigger( delay_cancel = None @callback - def template_listener(event, _, last_result, result): + def template_listener(event, updates): """Listen for state changes and calls action.""" nonlocal delay_cancel + result = updates.pop().result if delay_cancel: # pylint: disable=not-callable @@ -94,7 +99,9 @@ async def async_attach_trigger( delay_cancel = async_call_later(hass, period.seconds, call_action) info = async_track_template_result( - hass, value_template, template_listener, automation_info["variables"] + hass, + [TrackTemplate(value_template, automation_info["variables"])], + template_listener, ) unsub = info.async_remove diff --git a/homeassistant/components/universal/media_player.py b/homeassistant/components/universal/media_player.py index 7d1ac9953b63..c38afc139cf6 100644 --- a/homeassistant/components/universal/media_player.py +++ b/homeassistant/components/universal/media_player.py @@ -71,6 +71,7 @@ from homeassistant.const import ( from homeassistant.core import EVENT_HOMEASSISTANT_START, callback from homeassistant.exceptions import TemplateError from homeassistant.helpers import config_validation as cv +from homeassistant.helpers.event import TrackTemplate, async_track_template_result from homeassistant.helpers.reload import async_setup_reload_service from homeassistant.helpers.service import async_call_from_config @@ -149,8 +150,10 @@ class UniversalMediaPlayer(MediaPlayerEntity): self.async_schedule_update_ha_state(True) @callback - def _async_on_template_update(event, template, last_result, result): + def _async_on_template_update(event, updates): """Update ha state when dependencies update.""" + result = updates.pop().result + if isinstance(result, TemplateError): self._state_template_result = None else: @@ -158,8 +161,10 @@ class UniversalMediaPlayer(MediaPlayerEntity): self.async_schedule_update_ha_state(True) if self._state_template is not None: - result = self.hass.helpers.event.async_track_template_result( - self._state_template, _async_on_template_update + result = async_track_template_result( + self.hass, + [TrackTemplate(self._state_template, None)], + _async_on_template_update, ) self.hass.bus.async_listen_once( EVENT_HOMEASSISTANT_START, callback(lambda _: result.async_refresh()) diff --git a/homeassistant/components/websocket_api/commands.py b/homeassistant/components/websocket_api/commands.py index 4ed0292a9f4f..04ad0ae3d3ab 100644 --- a/homeassistant/components/websocket_api/commands.py +++ b/homeassistant/components/websocket_api/commands.py @@ -15,7 +15,7 @@ from homeassistant.exceptions import ( Unauthorized, ) from homeassistant.helpers import config_validation as cv, entity -from homeassistant.helpers.event import async_track_template_result +from homeassistant.helpers.event import TrackTemplate, async_track_template_result from homeassistant.helpers.service import async_get_all_descriptions from homeassistant.loader import IntegrationNotFound, async_get_integration @@ -255,19 +255,23 @@ def handle_render_template(hass, connection, msg): variables = msg.get("variables") @callback - def _template_listener(event, template, last_result, result): + def _template_listener(event, updates): + track_template_result = updates.pop() + result = track_template_result.result if isinstance(result, TemplateError): _LOGGER.error( "TemplateError('%s') " "while processing template '%s'", result, - template, + track_template_result.template, ) result = None connection.send_message(messages.event_message(msg["id"], {"result": result})) - info = async_track_template_result(hass, template, _template_listener, variables) + info = async_track_template_result( + hass, [TrackTemplate(template, variables)], _template_listener + ) connection.subscriptions[msg["id"]] = info.async_remove diff --git a/homeassistant/helpers/event.py b/homeassistant/helpers/event.py index 0530ad7cdc77..55bb68fc6ecc 100644 --- a/homeassistant/helpers/event.py +++ b/homeassistant/helpers/event.py @@ -1,14 +1,27 @@ """Helpers for listening to events.""" import asyncio +from dataclasses import dataclass from datetime import datetime, timedelta import functools as ft import logging import time -from typing import Any, Awaitable, Callable, Iterable, Optional, Union +from typing import ( + Any, + Awaitable, + Callable, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + Union, +) import attr from homeassistant.const import ( + ATTR_ENTITY_ID, ATTR_NOW, EVENT_CORE_CONFIG_UPDATE, EVENT_STATE_CHANGED, @@ -48,6 +61,37 @@ TRACK_ENTITY_REGISTRY_UPDATED_LISTENER = "track_entity_registry_updated_listener _LOGGER = logging.getLogger(__name__) +@dataclass +class TrackTemplate: + """Class for keeping track of a template with variables. + + The template is template to calculate. + The variables are variables to pass to the template. + """ + + template: Template + variables: TemplateVarsType + + +@dataclass +class TrackTemplateResult: + """Class for result of template tracking. + + template + The template that has changed. + last_result + The output from the template on the last successful run, or None + if no previous successful run. + result + Result from the template run. This will be a string or an + TemplateError if the template resulted in an error. + """ + + template: Template + last_result: Union[str, None, TemplateError] + result: Union[str, TemplateError] + + def threaded_listener_factory(async_factory: Callable[..., Any]) -> CALLBACK_TYPE: """Convert an async event helper to a threaded one.""" @@ -396,13 +440,16 @@ def async_track_template( """ @callback - def state_changed_listener( - event: Event, - template: Template, - last_result: Optional[str], - result: Union[str, TemplateError], + def _template_changed_listener( + event: Event, updates: List[TrackTemplateResult] ) -> None: """Check if condition is correct and run action.""" + track_result = updates.pop() + + template = track_result.template + last_result = track_result.last_result + result = track_result.result + if isinstance(result, TemplateError): _LOGGER.error( "Error while processing template: %s", @@ -411,7 +458,11 @@ def async_track_template( ) return - if result_as_boolean(last_result) or not result_as_boolean(result): + if ( + not isinstance(last_result, TemplateError) + and result_as_boolean(last_result) + or not result_as_boolean(result) + ): return hass.async_run_job( @@ -422,7 +473,7 @@ def async_track_template( ) info = async_track_template_result( - hass, template, state_changed_listener, variables + hass, [TrackTemplate(template, variables)], _template_changed_listener ) return info.async_remove @@ -431,76 +482,89 @@ def async_track_template( track_template = threaded_listener_factory(async_track_template) -_UNCHANGED = object() - - class _TrackTemplateResultInfo: """Handle removal / refresh of tracker.""" def __init__( self, hass: HomeAssistant, - template: Template, + track_templates: Iterable[TrackTemplate], action: Callable, - variables: Optional[TemplateVarsType], ): """Handle removal / refresh of tracker init.""" self.hass = hass - self._template = template - self._template.hass = hass self._action = action - self._variables = variables - self._last_result: Optional[Union[str, TemplateError]] = None + + for track_template_ in track_templates: + track_template_.template.hass = hass + self._track_templates = track_templates + self._all_listener: Optional[Callable] = None self._domains_listener: Optional[Callable] = None self._entities_listener: Optional[Callable] = None - self._info: Optional[RenderInfo] = None - self._last_info: Optional[RenderInfo] = None + + self._last_result: Dict[Template, Union[str, TemplateError]] = {} + self._last_info: Dict[Template, RenderInfo] = {} + self._info: Dict[Template, RenderInfo] = {} + self._last_domains: Set = set() + self._last_entities: Set = set() def async_setup(self) -> None: """Activation of template tracking.""" - self._info = self._template.async_render_to_info(self._variables) - if self._info.exception: - _LOGGER.error( - "Error while processing template: %s", - self._template.template, - exc_info=self._info.exception, - ) + for track_template_ in self._track_templates: + template = track_template_.template + variables = track_template_.variables + + self._info[template] = template.async_render_to_info(variables) + if self._info[template].exception: + _LOGGER.error( + "Error while processing template: %s", + track_template_.template, + exc_info=self._info[template].exception, + ) + + self._last_info = self._info.copy() self._create_listeners() - self._last_info = self._info @property def _needs_all_listener(self) -> bool: - assert self._info + for track_template_ in self._track_templates: + template = track_template_.template - # Tracking all states - if self._info.all_states: - return True + # Tracking all states + if self._info[template].all_states: + return True - # Previous call had an exception - # so we do not know which states - # to track - if self._info.exception: - return True + # Previous call had an exception + # so we do not know which states + # to track + if self._info[template].exception: + return True return False + @property + def _all_templates_are_static(self) -> bool: + for track_template_ in self._track_templates: + if not self._info[track_template_.template].is_static: + return False + + return True + @callback def _create_listeners(self) -> None: - assert self._info - - if self._info.is_static: + if self._all_templates_are_static: return if self._needs_all_listener: self._setup_all_listener() return - if self._info.domains: - self._setup_domains_listener() - - if self._info.entities or self._info.domains: - self._setup_entities_listener() + self._last_entities, self._last_domains = _entities_domains_from_info( + self._info.values() + ) + self._setup_domains_listener(self._last_domains) + self._setup_entities_listener(self._last_domains, self._last_entities) @callback def _cancel_domains_listener(self) -> None: @@ -525,12 +589,11 @@ class _TrackTemplateResultInfo: @callback def _update_listeners(self) -> None: - assert self._info - assert self._last_info - if self._needs_all_listener: if self._all_listener: return + self._last_domains = set() + self._last_entities = set() self._cancel_domains_listener() self._cancel_entities_listener() self._setup_all_listener() @@ -540,27 +603,26 @@ class _TrackTemplateResultInfo: if had_all_listener: self._cancel_all_listener() - domains_changed = self._info.domains != self._last_info.domains + entities, domains = _entities_domains_from_info(self._info.values()) + domains_changed = domains != self._last_domains + if had_all_listener or domains_changed: domains_changed = True self._cancel_domains_listener() - self._setup_domains_listener() + self._setup_domains_listener(domains) - if ( - had_all_listener - or domains_changed - or self._info.entities != self._last_info.entities - ): + if had_all_listener or domains_changed or entities != self._last_entities: self._cancel_entities_listener() - self._setup_entities_listener() + self._setup_entities_listener(domains, entities) + + self._last_domains = domains + self._last_entities = entities @callback - def _setup_entities_listener(self) -> None: - assert self._info - - entities = set(self._info.entities) - for entity_id in self.hass.states.async_entity_ids(self._info.domains): - entities.add(entity_id) + def _setup_entities_listener(self, domains: Set, entities: Set) -> None: + if domains: + entities = entities.copy() + entities.update(self.hass.states.async_entity_ids(domains)) # Entities has changed to none if not entities: @@ -571,15 +633,12 @@ class _TrackTemplateResultInfo: ) @callback - def _setup_domains_listener(self) -> None: - assert self._info - - # Domains has changed to none - if not self._info.domains: + def _setup_domains_listener(self, domains: Set) -> None: + if not domains: return self._domains_listener = async_track_state_added_domain( - self.hass, self._info.domains, self._refresh + self.hass, domains, self._refresh ) @callback @@ -596,40 +655,67 @@ class _TrackTemplateResultInfo: self._cancel_entities_listener() @callback - def async_refresh(self, variables: Any = _UNCHANGED) -> None: + def async_refresh(self) -> None: """Force recalculate the template.""" - if variables is not _UNCHANGED: - self._variables = variables self._refresh(None) @callback def _refresh(self, event: Optional[Event]) -> None: - self._info = self._template.async_render_to_info(self._variables) - self._update_listeners() - self._last_info = self._info + entity_id = event and event.data.get(ATTR_ENTITY_ID) + updates = [] + info_changed = False - try: - result: Union[str, TemplateError] = self._info.result - except TemplateError as ex: - result = ex + for track_template_ in self._track_templates: + template = track_template_.template + if ( + entity_id + and len(self._last_info) > 1 + and not self._last_info[template].filter_lifecycle(entity_id) + ): + continue - # Check to see if the result has changed - if result == self._last_result: + self._info[template] = template.async_render_to_info( + track_template_.variables + ) + info_changed = True + + try: + result: Union[str, TemplateError] = self._info[template].result + except TemplateError as ex: + result = ex + + last_result = self._last_result.get(template) + + # Check to see if the result has changed + if result == last_result: + continue + + if isinstance(result, TemplateError) and isinstance( + last_result, TemplateError + ): + continue + + updates.append(TrackTemplateResult(template, last_result, result)) + + if info_changed: + self._update_listeners() + self._last_info = self._info.copy() + + if not updates: return - if isinstance(result, TemplateError) and isinstance( - self._last_result, TemplateError - ): - return + for track_result in updates: + self._last_result[track_result.template] = track_result.result - self.hass.async_run_job( - self._action, event, self._template, self._last_result, result - ) - self._last_result = result + self.hass.async_run_job(self._action, event, updates) TrackTemplateResultListener = Callable[ - [Event, Template, Optional[str], Union[str, TemplateError]], None + [ + Event, + List[TrackTemplateResult], + ], + None, ] """Type for the listener for template results. @@ -638,14 +724,8 @@ TrackTemplateResultListener = Callable[ event Event that caused the template to change output. None if not triggered by an event. - template - The template that has changed. - last_result - The output from the template on the last successful run, or None - if no previous successful run. - result - Result from the template run. This will be a string or an - TemplateError if the template resulted in an error. + updates + A list of TrackTemplateResult """ @@ -653,9 +733,8 @@ TrackTemplateResultListener = Callable[ @bind_hass def async_track_template_result( hass: HomeAssistant, - template: Template, + track_templates: Iterable[TrackTemplate], action: TrackTemplateResultListener, - variables: Optional[TemplateVarsType] = None, ) -> _TrackTemplateResultInfo: """Add a listener that fires when a the result of a template changes. @@ -675,19 +754,18 @@ def async_track_template_result( ---------- hass Home assistant object. - template - The template to calculate. + track_templates + An iterable of TrackTemplate. + action Callable to call with results. - variables - Variables to pass to the template. Returns ------- Info object used to unregister the listener, and refresh the template. """ - tracker = _TrackTemplateResultInfo(hass, template, action, variables) + tracker = _TrackTemplateResultInfo(hass, track_templates, action) tracker.async_setup() return tracker @@ -1073,3 +1151,16 @@ def process_state_match( parameter_set = set(parameter) return lambda state: state in parameter_set + + +def _entities_domains_from_info(render_infos: Iterable[RenderInfo]) -> Tuple[Set, Set]: + """Combine from multiple RenderInfo.""" + entities = set() + domains = set() + + for render_info in render_infos: + if render_info.entities: + entities.update(render_info.entities) + if render_info.domains: + domains.update(render_info.domains) + return entities, domains diff --git a/homeassistant/helpers/template.py b/homeassistant/helpers/template.py index b9dc854cd2be..4d559a57c1f4 100644 --- a/homeassistant/helpers/template.py +++ b/homeassistant/helpers/template.py @@ -192,7 +192,7 @@ class RenderInfo: self.entities = frozenset(self.entities) self.domains = frozenset(self.domains) - if self.all_states: + if self.all_states or self.exception: return if not self.domains: diff --git a/tests/helpers/test_event.py b/tests/helpers/test_event.py index 6fb422e03e75..41d252177a44 100644 --- a/tests/helpers/test_event.py +++ b/tests/helpers/test_event.py @@ -14,6 +14,8 @@ from homeassistant.core import callback from homeassistant.exceptions import TemplateError from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED from homeassistant.helpers.event import ( + TrackTemplate, + TrackTemplateResult, async_call_later, async_track_point_in_time, async_track_point_in_utc_time, @@ -581,22 +583,35 @@ async def test_track_template_result(hass): "{{(states.sensor.test.state|int) + test }}", hass ) - def specific_run_callback(event, template, old_result, new_result): - specific_runs.append(int(new_result)) - - async_track_template_result(hass, template_condition, specific_run_callback) - - @ha.callback - def wildcard_run_callback(event, template, old_result, new_result): - wildcard_runs.append((int(old_result or 0), int(new_result))) - - async_track_template_result(hass, template_condition, wildcard_run_callback) - - async def wildercard_run_callback(event, template, old_result, new_result): - wildercard_runs.append((int(old_result or 0), int(new_result))) + def specific_run_callback(event, updates): + track_result = updates.pop() + specific_runs.append(int(track_result.result)) async_track_template_result( - hass, template_condition_var, wildercard_run_callback, {"test": 5} + hass, [TrackTemplate(template_condition, None)], specific_run_callback + ) + + @ha.callback + def wildcard_run_callback(event, updates): + track_result = updates.pop() + wildcard_runs.append( + (int(track_result.last_result or 0), int(track_result.result)) + ) + + async_track_template_result( + hass, [TrackTemplate(template_condition, None)], wildcard_run_callback + ) + + async def wildercard_run_callback(event, updates): + track_result = updates.pop() + wildercard_runs.append( + (int(track_result.last_result or 0), int(track_result.result)) + ) + + async_track_template_result( + hass, + [TrackTemplate(template_condition_var, {"test": 5})], + wildercard_run_callback, ) await hass.async_block_till_done() @@ -661,13 +676,15 @@ async def test_track_template_result_complex(hass): """ template_complex = Template(template_complex_str, hass) - def specific_run_callback(event, template, old_result, new_result): - specific_runs.append(new_result) + def specific_run_callback(event, updates): + specific_runs.append(updates.pop().result) hass.states.async_set("light.one", "on") hass.states.async_set("lock.one", "locked") - async_track_template_result(hass, template_complex, specific_run_callback) + async_track_template_result( + hass, [TrackTemplate(template_complex, None)], specific_run_callback + ) await hass.async_block_till_done() hass.states.async_set("sensor.domain", "light") @@ -742,14 +759,16 @@ async def test_track_template_result_with_wildcard(hass): """ template_complex = Template(template_complex_str, hass) - def specific_run_callback(event, template, old_result, new_result): - specific_runs.append(new_result) + def specific_run_callback(event, updates): + specific_runs.append(updates.pop().result) hass.states.async_set("cover.office_drapes", "closed") hass.states.async_set("cover.office_window", "closed") hass.states.async_set("cover.office_skylight", "open") - async_track_template_result(hass, template_complex, specific_run_callback) + async_track_template_result( + hass, [TrackTemplate(template_complex, None)], specific_run_callback + ) await hass.async_block_till_done() hass.states.async_set("cover.office_window", "open") @@ -786,10 +805,12 @@ async def test_track_template_result_with_group(hass): """ template_complex = Template(template_complex_str, hass) - def specific_run_callback(event, template, old_result, new_result): - specific_runs.append(new_result) + def specific_run_callback(event, updates): + specific_runs.append(updates.pop().result) - async_track_template_result(hass, template_complex, specific_run_callback) + async_track_template_result( + hass, [TrackTemplate(template_complex, None)], specific_run_callback + ) await hass.async_block_till_done() hass.states.async_set("sensor.power_1", 100.1) @@ -827,13 +848,12 @@ async def test_track_template_result_and_conditional(hass): template = Template(template_str, hass) - def specific_run_callback(event, template, old_result, new_result): - import pprint + def specific_run_callback(event, updates): + specific_runs.append(updates.pop().result) - pprint.pprint([event, template, old_result, new_result]) - specific_runs.append(new_result) - - async_track_template_result(hass, template, specific_run_callback) + async_track_template_result( + hass, [TrackTemplate(template, None)], specific_run_callback + ) await hass.async_block_till_done() hass.states.async_set("light.b", "on") @@ -869,21 +889,26 @@ async def test_track_template_result_iterator(hass): iterator_runs = [] @ha.callback - def iterator_callback(event, template, old_result, new_result): - iterator_runs.append(new_result) + def iterator_callback(event, updates): + iterator_runs.append(updates.pop().result) async_track_template_result( hass, - Template( - """ + [ + TrackTemplate( + Template( + """ {% for state in states.sensor %} {% if state.state == 'on' %} {{ state.entity_id }}, {% endif %} {% endfor %} """, - hass, - ), + hass, + ), + None, + ) + ], iterator_callback, ) await hass.async_block_till_done() @@ -896,16 +921,21 @@ async def test_track_template_result_iterator(hass): filter_runs = [] @ha.callback - def filter_callback(event, template, old_result, new_result): - filter_runs.append(new_result) + def filter_callback(event, updates): + filter_runs.append(updates.pop().result) async_track_template_result( hass, - Template( - """{{ states.sensor|selectattr("state","equalto","on") + [ + TrackTemplate( + Template( + """{{ states.sensor|selectattr("state","equalto","on") |join(",", attribute="entity_id") }}""", - hass, - ), + hass, + ), + None, + ) + ], filter_callback, ) await hass.async_block_till_done() @@ -931,21 +961,42 @@ async def test_track_template_result_errors(hass, caplog): syntax_error_runs = [] not_exist_runs = [] - def syntax_error_listener(event, template, last_result, result): - syntax_error_runs.append((event, template, last_result, result)) + @ha.callback + def syntax_error_listener(event, updates): + track_result = updates.pop() + syntax_error_runs.append( + ( + event, + track_result.template, + track_result.last_result, + track_result.result, + ) + ) - async_track_template_result(hass, template_syntax_error, syntax_error_listener) + async_track_template_result( + hass, [TrackTemplate(template_syntax_error, None)], syntax_error_listener + ) await hass.async_block_till_done() assert len(syntax_error_runs) == 0 assert "TemplateSyntaxError" in caplog.text + @ha.callback + def not_exist_runs_error_listener(event, updates): + template_track = updates.pop() + not_exist_runs.append( + ( + event, + template_track.template, + template_track.last_result, + template_track.result, + ) + ) + async_track_template_result( hass, - template_not_exist, - lambda event, template, last_result, result: ( - not_exist_runs.append((event, template, last_result, result)) - ), + [TrackTemplate(template_not_exist, None)], + not_exist_runs_error_listener, ) await hass.async_block_till_done() @@ -990,10 +1041,13 @@ async def test_track_template_result_refresh_cancel(hass): refresh_runs = [] - def refresh_listener(event, template, last_result, result): - refresh_runs.append(result) + @ha.callback + def refresh_listener(event, updates): + refresh_runs.append(updates.pop().result) - info = async_track_template_result(hass, template_refresh, refresh_listener) + info = async_track_template_result( + hass, [TrackTemplate(template_refresh, None)], refresh_listener + ) await hass.async_block_till_done() hass.states.async_set("switch.test", "off") @@ -1020,7 +1074,9 @@ async def test_track_template_result_refresh_cancel(hass): refresh_runs = [] info = async_track_template_result( - hass, template_refresh, refresh_listener, {"value": "duck"} + hass, + [TrackTemplate(template_refresh, {"value": "duck"})], + refresh_listener, ) await hass.async_block_till_done() info.async_refresh() @@ -1032,9 +1088,132 @@ async def test_track_template_result_refresh_cancel(hass): await hass.async_block_till_done() assert refresh_runs == ["duck"] - info.async_refresh({"value": "dog"}) + +async def test_async_track_template_result_multiple_templates(hass): + """Test tracking multiple templates.""" + + template_1 = Template("{{ states.switch.test.state == 'on' }}") + template_2 = Template("{{ states.switch.test.state == 'on' }}") + template_3 = Template("{{ states.switch.test.state == 'off' }}") + template_4 = Template( + "{{ states.binary_sensor | map(attribute='entity_id') | list }}" + ) + + refresh_runs = [] + + @ha.callback + def refresh_listener(event, updates): + refresh_runs.append(updates) + + async_track_template_result( + hass, + [ + TrackTemplate(template_1, None), + TrackTemplate(template_2, None), + TrackTemplate(template_3, None), + TrackTemplate(template_4, None), + ], + refresh_listener, + ) + + hass.states.async_set("switch.test", "on") await hass.async_block_till_done() - assert refresh_runs == ["duck", "dog"] + + assert refresh_runs == [ + [ + TrackTemplateResult(template_1, None, "True"), + TrackTemplateResult(template_2, None, "True"), + TrackTemplateResult(template_3, None, "False"), + ] + ] + + refresh_runs = [] + hass.states.async_set("switch.test", "off") + await hass.async_block_till_done() + + assert refresh_runs == [ + [ + TrackTemplateResult(template_1, "True", "False"), + TrackTemplateResult(template_2, "True", "False"), + TrackTemplateResult(template_3, "False", "True"), + ] + ] + + refresh_runs = [] + hass.states.async_set("binary_sensor.test", "off") + await hass.async_block_till_done() + + assert refresh_runs == [ + [TrackTemplateResult(template_4, None, "['binary_sensor.test']")] + ] + + +async def test_async_track_template_result_multiple_templates_mixing_domain(hass): + """Test tracking multiple templates when tracking entities and an entire domain.""" + + template_1 = Template("{{ states.switch.test.state == 'on' }}") + template_2 = Template("{{ states.switch.test.state == 'on' }}") + template_3 = Template("{{ states.switch.test.state == 'off' }}") + template_4 = Template("{{ states.switch | map(attribute='entity_id') | list }}") + + refresh_runs = [] + + @ha.callback + def refresh_listener(event, updates): + refresh_runs.append(updates) + + async_track_template_result( + hass, + [ + TrackTemplate(template_1, None), + TrackTemplate(template_2, None), + TrackTemplate(template_3, None), + TrackTemplate(template_4, None), + ], + refresh_listener, + ) + + hass.states.async_set("switch.test", "on") + await hass.async_block_till_done() + + assert refresh_runs == [ + [ + TrackTemplateResult(template_1, None, "True"), + TrackTemplateResult(template_2, None, "True"), + TrackTemplateResult(template_3, None, "False"), + TrackTemplateResult(template_4, None, "['switch.test']"), + ] + ] + + refresh_runs = [] + hass.states.async_set("switch.test", "off") + await hass.async_block_till_done() + + assert refresh_runs == [ + [ + TrackTemplateResult(template_1, "True", "False"), + TrackTemplateResult(template_2, "True", "False"), + TrackTemplateResult(template_3, "False", "True"), + ] + ] + + refresh_runs = [] + hass.states.async_set("binary_sensor.test", "off") + await hass.async_block_till_done() + + assert refresh_runs == [] + + refresh_runs = [] + hass.states.async_set("switch.new", "off") + await hass.async_block_till_done() + + assert refresh_runs == [ + [ + TrackTemplateResult( + template_4, "['switch.test']", "['switch.new', 'switch.test']" + ) + ] + ] async def test_track_same_state_simple_no_trigger(hass):