Support templating MQTT triggers (#45614)

* Add support for limited templates (no HASS access)

* Pass variables to automation triggers

* Support templates in MQTT triggers

* Spelling

* Handle trigger referenced by variables

* Raise on unsupported function in limited templates

* Validate MQTT trigger schema in MQTT device trigger

* Add trigger_variables to automation config schema

* Don't print stacktrace when setting up trigger throws

* Make pylint happy

* Add trigger_variables to variables

* Add debug prints, document limited template

* Add tests

* Validate MQTT trigger topic early when possible

* Improve valid_subscribe_topic_template
This commit is contained in:
Erik Montnemery 2021-02-08 10:50:38 +01:00 committed by GitHub
parent b9b1caf4d7
commit 047f16772f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 262 additions and 15 deletions

View File

@ -60,6 +60,7 @@ from .const import (
CONF_ACTION, CONF_ACTION,
CONF_INITIAL_STATE, CONF_INITIAL_STATE,
CONF_TRIGGER, CONF_TRIGGER,
CONF_TRIGGER_VARIABLES,
DEFAULT_INITIAL_STATE, DEFAULT_INITIAL_STATE,
DOMAIN, DOMAIN,
LOGGER, LOGGER,
@ -221,6 +222,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
action_script, action_script,
initial_state, initial_state,
variables, variables,
trigger_variables,
): ):
"""Initialize an automation entity.""" """Initialize an automation entity."""
self._id = automation_id self._id = automation_id
@ -236,6 +238,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
self._referenced_devices: Optional[Set[str]] = None self._referenced_devices: Optional[Set[str]] = None
self._logger = LOGGER self._logger = LOGGER
self._variables: ScriptVariables = variables self._variables: ScriptVariables = variables
self._trigger_variables: ScriptVariables = trigger_variables
@property @property
def name(self): def name(self):
@ -471,6 +474,16 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
def log_cb(level, msg, **kwargs): def log_cb(level, msg, **kwargs):
self._logger.log(level, "%s %s", msg, self._name, **kwargs) self._logger.log(level, "%s %s", msg, self._name, **kwargs)
variables = None
if self._trigger_variables:
try:
variables = self._trigger_variables.async_render(
cast(HomeAssistant, self.hass), None, limited=True
)
except template.TemplateError as err:
self._logger.error("Error rendering trigger variables: %s", err)
return None
return await async_initialize_triggers( return await async_initialize_triggers(
cast(HomeAssistant, self.hass), cast(HomeAssistant, self.hass),
self._trigger_config, self._trigger_config,
@ -479,6 +492,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
self._name, self._name,
log_cb, log_cb,
home_assistant_start, home_assistant_start,
variables,
) )
@property @property
@ -556,6 +570,18 @@ async def _async_process_config(
else: else:
cond_func = None cond_func = None
# Add trigger variables to variables
variables = None
if CONF_TRIGGER_VARIABLES in config_block:
variables = ScriptVariables(
dict(config_block[CONF_TRIGGER_VARIABLES].as_dict())
)
if CONF_VARIABLES in config_block:
if variables:
variables.variables.update(config_block[CONF_VARIABLES].as_dict())
else:
variables = config_block[CONF_VARIABLES]
entity = AutomationEntity( entity = AutomationEntity(
automation_id, automation_id,
name, name,
@ -563,7 +589,8 @@ async def _async_process_config(
cond_func, cond_func,
action_script, action_script,
initial_state, initial_state,
config_block.get(CONF_VARIABLES), variables,
config_block.get(CONF_TRIGGER_VARIABLES),
) )
entities.append(entity) entities.append(entity)

View File

@ -21,6 +21,7 @@ from .const import (
CONF_HIDE_ENTITY, CONF_HIDE_ENTITY,
CONF_INITIAL_STATE, CONF_INITIAL_STATE,
CONF_TRIGGER, CONF_TRIGGER,
CONF_TRIGGER_VARIABLES,
DOMAIN, DOMAIN,
) )
from .helpers import async_get_blueprints from .helpers import async_get_blueprints
@ -43,6 +44,7 @@ PLATFORM_SCHEMA = vol.All(
vol.Required(CONF_TRIGGER): cv.TRIGGER_SCHEMA, vol.Required(CONF_TRIGGER): cv.TRIGGER_SCHEMA,
vol.Optional(CONF_CONDITION): _CONDITION_SCHEMA, vol.Optional(CONF_CONDITION): _CONDITION_SCHEMA,
vol.Optional(CONF_VARIABLES): cv.SCRIPT_VARIABLES_SCHEMA, vol.Optional(CONF_VARIABLES): cv.SCRIPT_VARIABLES_SCHEMA,
vol.Optional(CONF_TRIGGER_VARIABLES): cv.SCRIPT_VARIABLES_SCHEMA,
vol.Required(CONF_ACTION): cv.SCRIPT_SCHEMA, vol.Required(CONF_ACTION): cv.SCRIPT_SCHEMA,
}, },
script.SCRIPT_MODE_SINGLE, script.SCRIPT_MODE_SINGLE,

View File

@ -3,6 +3,7 @@ import logging
CONF_ACTION = "action" CONF_ACTION = "action"
CONF_TRIGGER = "trigger" CONF_TRIGGER = "trigger"
CONF_TRIGGER_VARIABLES = "trigger_variables"
DOMAIN = "automation" DOMAIN = "automation"
CONF_DESCRIPTION = "description" CONF_DESCRIPTION = "description"

View File

@ -89,12 +89,14 @@ class TriggerInstance:
async def async_attach_trigger(self): async def async_attach_trigger(self):
"""Attach MQTT trigger.""" """Attach MQTT trigger."""
mqtt_config = { mqtt_config = {
mqtt_trigger.CONF_PLATFORM: mqtt.DOMAIN,
mqtt_trigger.CONF_TOPIC: self.trigger.topic, mqtt_trigger.CONF_TOPIC: self.trigger.topic,
mqtt_trigger.CONF_ENCODING: DEFAULT_ENCODING, mqtt_trigger.CONF_ENCODING: DEFAULT_ENCODING,
mqtt_trigger.CONF_QOS: self.trigger.qos, mqtt_trigger.CONF_QOS: self.trigger.qos,
} }
if self.trigger.payload: if self.trigger.payload:
mqtt_config[CONF_PAYLOAD] = self.trigger.payload mqtt_config[CONF_PAYLOAD] = self.trigger.payload
mqtt_config = mqtt_trigger.TRIGGER_SCHEMA(mqtt_config)
if self.remove: if self.remove:
self.remove() self.remove()

View File

@ -1,11 +1,12 @@
"""Offer MQTT listening automation rules.""" """Offer MQTT listening automation rules."""
import json import json
import logging
import voluptuous as vol import voluptuous as vol
from homeassistant.const import CONF_PAYLOAD, CONF_PLATFORM from homeassistant.const import CONF_PAYLOAD, CONF_PLATFORM
from homeassistant.core import HassJob, callback from homeassistant.core import HassJob, callback
import homeassistant.helpers.config_validation as cv from homeassistant.helpers import config_validation as cv, template
from .. import mqtt from .. import mqtt
@ -20,8 +21,8 @@ DEFAULT_QOS = 0
TRIGGER_SCHEMA = vol.Schema( TRIGGER_SCHEMA = vol.Schema(
{ {
vol.Required(CONF_PLATFORM): mqtt.DOMAIN, vol.Required(CONF_PLATFORM): mqtt.DOMAIN,
vol.Required(CONF_TOPIC): mqtt.util.valid_subscribe_topic, vol.Required(CONF_TOPIC): mqtt.util.valid_subscribe_topic_template,
vol.Optional(CONF_PAYLOAD): cv.string, vol.Optional(CONF_PAYLOAD): cv.template,
vol.Optional(CONF_ENCODING, default=DEFAULT_ENCODING): cv.string, vol.Optional(CONF_ENCODING, default=DEFAULT_ENCODING): cv.string,
vol.Optional(CONF_QOS, default=DEFAULT_QOS): vol.All( vol.Optional(CONF_QOS, default=DEFAULT_QOS): vol.All(
vol.Coerce(int), vol.In([0, 1, 2]) vol.Coerce(int), vol.In([0, 1, 2])
@ -29,6 +30,8 @@ TRIGGER_SCHEMA = vol.Schema(
} }
) )
_LOGGER = logging.getLogger(__name__)
async def async_attach_trigger(hass, config, action, automation_info): async def async_attach_trigger(hass, config, action, automation_info):
"""Listen for state changes based on configuration.""" """Listen for state changes based on configuration."""
@ -37,6 +40,18 @@ async def async_attach_trigger(hass, config, action, automation_info):
encoding = config[CONF_ENCODING] or None encoding = config[CONF_ENCODING] or None
qos = config[CONF_QOS] qos = config[CONF_QOS]
job = HassJob(action) job = HassJob(action)
variables = None
if automation_info:
variables = automation_info.get("variables")
template.attach(hass, payload)
if payload:
payload = payload.async_render(variables, limited=True)
template.attach(hass, topic)
if isinstance(topic, template.Template):
topic = topic.async_render(variables, limited=True)
topic = mqtt.util.valid_subscribe_topic(topic)
@callback @callback
def mqtt_automation_listener(mqttmsg): def mqtt_automation_listener(mqttmsg):
@ -57,6 +72,10 @@ async def async_attach_trigger(hass, config, action, automation_info):
hass.async_run_hass_job(job, {"trigger": data}) hass.async_run_hass_job(job, {"trigger": data})
_LOGGER.debug(
"Attaching MQTT trigger for topic: '%s', payload: '%s'", topic, payload
)
remove = await mqtt.async_subscribe( remove = await mqtt.async_subscribe(
hass, topic, mqtt_automation_listener, encoding=encoding, qos=qos hass, topic, mqtt_automation_listener, encoding=encoding, qos=qos
) )

View File

@ -4,7 +4,7 @@ from typing import Any
import voluptuous as vol import voluptuous as vol
from homeassistant.const import CONF_PAYLOAD from homeassistant.const import CONF_PAYLOAD
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv, template
from .const import ( from .const import (
ATTR_PAYLOAD, ATTR_PAYLOAD,
@ -61,6 +61,16 @@ def valid_subscribe_topic(value: Any) -> str:
return value return value
def valid_subscribe_topic_template(value: Any) -> template.Template:
"""Validate either a jinja2 template or a valid MQTT subscription topic."""
tpl = template.Template(value)
if tpl.is_static:
valid_subscribe_topic(value)
return tpl
def valid_publish_topic(value: Any) -> str: def valid_publish_topic(value: Any) -> str:
"""Validate that we can publish using this MQTT topic.""" """Validate that we can publish using this MQTT topic."""
value = valid_topic(value) value = valid_topic(value)

View File

@ -572,7 +572,7 @@ def dynamic_template(value: Optional[Any]) -> template_helper.Template:
if isinstance(value, (list, dict, template_helper.Template)): if isinstance(value, (list, dict, template_helper.Template)):
raise vol.Invalid("template value should be a string") raise vol.Invalid("template value should be a string")
if not template_helper.is_template_string(str(value)): if not template_helper.is_template_string(str(value)):
raise vol.Invalid("template value does not contain a dynmamic template") raise vol.Invalid("template value does not contain a dynamic template")
template_value = template_helper.Template(str(value)) # type: ignore template_value = template_helper.Template(str(value)) # type: ignore
try: try:

View File

@ -21,6 +21,7 @@ class ScriptVariables:
run_variables: Optional[Mapping[str, Any]], run_variables: Optional[Mapping[str, Any]],
*, *,
render_as_defaults: bool = True, render_as_defaults: bool = True,
limited: bool = False,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Render script variables. """Render script variables.
@ -55,7 +56,9 @@ class ScriptVariables:
if render_as_defaults and key in rendered_variables: if render_as_defaults and key in rendered_variables:
continue continue
rendered_variables[key] = template.render_complex(value, rendered_variables) rendered_variables[key] = template.render_complex(
value, rendered_variables, limited
)
return rendered_variables return rendered_variables

View File

@ -84,7 +84,9 @@ def attach(hass: HomeAssistantType, obj: Any) -> None:
obj.hass = hass obj.hass = hass
def render_complex(value: Any, variables: TemplateVarsType = None) -> Any: def render_complex(
value: Any, variables: TemplateVarsType = None, limited: bool = False
) -> Any:
"""Recursive template creator helper function.""" """Recursive template creator helper function."""
if isinstance(value, list): if isinstance(value, list):
return [render_complex(item, variables) for item in value] return [render_complex(item, variables) for item in value]
@ -94,7 +96,7 @@ def render_complex(value: Any, variables: TemplateVarsType = None) -> Any:
for key, item in value.items() for key, item in value.items()
} }
if isinstance(value, Template): if isinstance(value, Template):
return value.async_render(variables) return value.async_render(variables, limited=limited)
return value return value
@ -279,6 +281,7 @@ class Template:
"is_static", "is_static",
"_compiled_code", "_compiled_code",
"_compiled", "_compiled",
"_limited",
) )
def __init__(self, template, hass=None): def __init__(self, template, hass=None):
@ -291,10 +294,11 @@ class Template:
self._compiled: Optional[Template] = None self._compiled: Optional[Template] = None
self.hass = hass self.hass = hass
self.is_static = not is_template_string(template) self.is_static = not is_template_string(template)
self._limited = None
@property @property
def _env(self) -> "TemplateEnvironment": def _env(self) -> "TemplateEnvironment":
if self.hass is None: if self.hass is None or self._limited:
return _NO_HASS_ENV return _NO_HASS_ENV
ret: Optional[TemplateEnvironment] = self.hass.data.get(_ENVIRONMENT) ret: Optional[TemplateEnvironment] = self.hass.data.get(_ENVIRONMENT)
if ret is None: if ret is None:
@ -315,9 +319,13 @@ class Template:
self, self,
variables: TemplateVarsType = None, variables: TemplateVarsType = None,
parse_result: bool = True, parse_result: bool = True,
limited: bool = False,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Render given template.""" """Render given template.
If limited is True, the template is not allowed to access any function or filter depending on hass or the state machine.
"""
if self.is_static: if self.is_static:
if self.hass.config.legacy_templates or not parse_result: if self.hass.config.legacy_templates or not parse_result:
return self.template return self.template
@ -325,7 +333,7 @@ class Template:
return run_callback_threadsafe( return run_callback_threadsafe(
self.hass.loop, self.hass.loop,
partial(self.async_render, variables, parse_result, **kwargs), partial(self.async_render, variables, parse_result, limited, **kwargs),
).result() ).result()
@callback @callback
@ -333,18 +341,21 @@ class Template:
self, self,
variables: TemplateVarsType = None, variables: TemplateVarsType = None,
parse_result: bool = True, parse_result: bool = True,
limited: bool = False,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Render given template. """Render given template.
This method must be run in the event loop. This method must be run in the event loop.
If limited is True, the template is not allowed to access any function or filter depending on hass or the state machine.
""" """
if self.is_static: if self.is_static:
if self.hass.config.legacy_templates or not parse_result: if self.hass.config.legacy_templates or not parse_result:
return self.template return self.template
return self._parse_result(self.template) return self._parse_result(self.template)
compiled = self._compiled or self._ensure_compiled() compiled = self._compiled or self._ensure_compiled(limited)
if variables is not None: if variables is not None:
kwargs.update(variables) kwargs.update(variables)
@ -519,12 +530,16 @@ class Template:
) )
return value if error_value is _SENTINEL else error_value return value if error_value is _SENTINEL else error_value
def _ensure_compiled(self) -> "Template": def _ensure_compiled(self, limited: bool = False) -> "Template":
"""Bind a template to a specific hass instance.""" """Bind a template to a specific hass instance."""
self.ensure_valid() self.ensure_valid()
assert self.hass is not None, "hass variable not set on template" assert self.hass is not None, "hass variable not set on template"
assert (
self._limited is None or self._limited == limited
), "can't change between limited and non limited template"
self._limited = limited
env = self._env env = self._env
self._compiled = cast( self._compiled = cast(
@ -1352,6 +1367,31 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment):
self.globals["strptime"] = strptime self.globals["strptime"] = strptime
self.globals["urlencode"] = urlencode self.globals["urlencode"] = urlencode
if hass is None: if hass is None:
def unsupported(name):
def warn_unsupported(*args, **kwargs):
raise TemplateError(
f"Use of '{name}' is not supported in limited templates"
)
return warn_unsupported
hass_globals = [
"closest",
"distance",
"expand",
"is_state",
"is_state_attr",
"state_attr",
"states",
"utcnow",
"now",
]
hass_filters = ["closest", "expand"]
for glob in hass_globals:
self.globals[glob] = unsupported(glob)
for filt in hass_filters:
self.filters[filt] = unsupported(filt)
return return
# We mark these as a context functions to ensure they get # We mark these as a context functions to ensure they get

View File

@ -8,6 +8,7 @@ import voluptuous as vol
from homeassistant.const import CONF_PLATFORM from homeassistant.const import CONF_PLATFORM
from homeassistant.core import CALLBACK_TYPE, callback from homeassistant.core import CALLBACK_TYPE, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.typing import ConfigType, HomeAssistantType from homeassistant.helpers.typing import ConfigType, HomeAssistantType
from homeassistant.loader import IntegrationNotFound, async_get_integration from homeassistant.loader import IntegrationNotFound, async_get_integration
@ -79,7 +80,9 @@ async def async_initialize_triggers(
removes = [] removes = []
for result in attach_results: for result in attach_results:
if isinstance(result, Exception): if isinstance(result, HomeAssistantError):
log_cb(logging.ERROR, f"Got error '{result}' when setting up triggers for")
elif isinstance(result, Exception):
log_cb(logging.ERROR, "Error setting up trigger", exc_info=result) log_cb(logging.ERROR, "Error setting up trigger", exc_info=result)
elif result is None: elif result is None:
log_cb( log_cb(

View File

@ -1237,6 +1237,94 @@ async def test_automation_variables(hass, caplog):
assert len(calls) == 3 assert len(calls) == 3
async def test_automation_trigger_variables(hass, caplog):
"""Test automation trigger variables."""
calls = async_mock_service(hass, "test", "automation")
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: [
{
"variables": {
"event_type": "{{ trigger.event.event_type }}",
},
"trigger_variables": {
"test_var": "defined_in_config",
},
"trigger": {"platform": "event", "event_type": "test_event"},
"action": {
"service": "test.automation",
"data": {
"value": "{{ test_var }}",
"event_type": "{{ event_type }}",
},
},
},
{
"variables": {
"event_type": "{{ trigger.event.event_type }}",
"test_var": "overridden_in_config",
},
"trigger_variables": {
"test_var": "defined_in_config",
},
"trigger": {"platform": "event", "event_type": "test_event_2"},
"action": {
"service": "test.automation",
"data": {
"value": "{{ test_var }}",
"event_type": "{{ event_type }}",
},
},
},
]
},
)
hass.bus.async_fire("test_event")
await hass.async_block_till_done()
assert len(calls) == 1
assert calls[0].data["value"] == "defined_in_config"
assert calls[0].data["event_type"] == "test_event"
hass.bus.async_fire("test_event_2")
await hass.async_block_till_done()
assert len(calls) == 2
assert calls[1].data["value"] == "overridden_in_config"
assert calls[1].data["event_type"] == "test_event_2"
assert "Error rendering variables" not in caplog.text
async def test_automation_bad_trigger_variables(hass, caplog):
"""Test automation trigger variables accessing hass is rejected."""
calls = async_mock_service(hass, "test", "automation")
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: [
{
"trigger_variables": {
"test_var": "{{ states('foo.bar') }}",
},
"trigger": {"platform": "event", "event_type": "test_event"},
"action": {
"service": "test.automation",
},
},
]
},
)
hass.bus.async_fire("test_event")
assert "Use of 'states' is not supported in limited templates" in caplog.text
await hass.async_block_till_done()
assert len(calls) == 0
async def test_blueprint_automation(hass, calls): async def test_blueprint_automation(hass, calls):
"""Test blueprint automation.""" """Test blueprint automation."""
assert await async_setup_component( assert await async_setup_component(

View File

@ -81,6 +81,58 @@ async def test_if_fires_on_topic_and_payload_match(hass, calls):
assert len(calls) == 1 assert len(calls) == 1
async def test_if_fires_on_templated_topic_and_payload_match(hass, calls):
"""Test if message is fired on templated topic and payload match."""
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: {
"trigger": {
"platform": "mqtt",
"topic": "test-topic-{{ sqrt(16)|round }}",
"payload": '{{ "foo"|regex_replace("foo", "bar") }}',
},
"action": {"service": "test.automation"},
}
},
)
async_fire_mqtt_message(hass, "test-topic-", "foo")
await hass.async_block_till_done()
assert len(calls) == 0
async_fire_mqtt_message(hass, "test-topic-4", "foo")
await hass.async_block_till_done()
assert len(calls) == 0
async_fire_mqtt_message(hass, "test-topic-4", "bar")
await hass.async_block_till_done()
assert len(calls) == 1
async def test_non_allowed_templates(hass, calls, caplog):
"""Test non allowed function in template."""
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: {
"trigger": {
"platform": "mqtt",
"topic": "test-topic-{{ states() }}",
},
"action": {"service": "test.automation"},
}
},
)
assert (
"Got error 'TemplateError: str: Use of 'states' is not supported in limited templates' when setting up triggers"
in caplog.text
)
async def test_if_not_fires_on_topic_but_no_payload_match(hass, calls): async def test_if_not_fires_on_topic_but_no_payload_match(hass, calls):
"""Test if message is not fired on topic but no payload.""" """Test if message is not fired on topic but no payload."""
assert await async_setup_component( assert await async_setup_component(