Support templating MQTT triggers (#45614)

* Add support for limited templates (no HASS access)

* Pass variables to automation triggers

* Support templates in MQTT triggers

* Spelling

* Handle trigger referenced by variables

* Raise on unsupported function in limited templates

* Validate MQTT trigger schema in MQTT device trigger

* Add trigger_variables to automation config schema

* Don't print stacktrace when setting up trigger throws

* Make pylint happy

* Add trigger_variables to variables

* Add debug prints, document limited template

* Add tests

* Validate MQTT trigger topic early when possible

* Improve valid_subscribe_topic_template
This commit is contained in:
Erik Montnemery 2021-02-08 10:50:38 +01:00 committed by GitHub
parent b9b1caf4d7
commit 047f16772f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 262 additions and 15 deletions

View File

@ -60,6 +60,7 @@ from .const import (
CONF_ACTION,
CONF_INITIAL_STATE,
CONF_TRIGGER,
CONF_TRIGGER_VARIABLES,
DEFAULT_INITIAL_STATE,
DOMAIN,
LOGGER,
@ -221,6 +222,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
action_script,
initial_state,
variables,
trigger_variables,
):
"""Initialize an automation entity."""
self._id = automation_id
@ -236,6 +238,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
self._referenced_devices: Optional[Set[str]] = None
self._logger = LOGGER
self._variables: ScriptVariables = variables
self._trigger_variables: ScriptVariables = trigger_variables
@property
def name(self):
@ -471,6 +474,16 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
def log_cb(level, msg, **kwargs):
self._logger.log(level, "%s %s", msg, self._name, **kwargs)
variables = None
if self._trigger_variables:
try:
variables = self._trigger_variables.async_render(
cast(HomeAssistant, self.hass), None, limited=True
)
except template.TemplateError as err:
self._logger.error("Error rendering trigger variables: %s", err)
return None
return await async_initialize_triggers(
cast(HomeAssistant, self.hass),
self._trigger_config,
@ -479,6 +492,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
self._name,
log_cb,
home_assistant_start,
variables,
)
@property
@ -556,6 +570,18 @@ async def _async_process_config(
else:
cond_func = None
# Add trigger variables to variables
variables = None
if CONF_TRIGGER_VARIABLES in config_block:
variables = ScriptVariables(
dict(config_block[CONF_TRIGGER_VARIABLES].as_dict())
)
if CONF_VARIABLES in config_block:
if variables:
variables.variables.update(config_block[CONF_VARIABLES].as_dict())
else:
variables = config_block[CONF_VARIABLES]
entity = AutomationEntity(
automation_id,
name,
@ -563,7 +589,8 @@ async def _async_process_config(
cond_func,
action_script,
initial_state,
config_block.get(CONF_VARIABLES),
variables,
config_block.get(CONF_TRIGGER_VARIABLES),
)
entities.append(entity)

View File

@ -21,6 +21,7 @@ from .const import (
CONF_HIDE_ENTITY,
CONF_INITIAL_STATE,
CONF_TRIGGER,
CONF_TRIGGER_VARIABLES,
DOMAIN,
)
from .helpers import async_get_blueprints
@ -43,6 +44,7 @@ PLATFORM_SCHEMA = vol.All(
vol.Required(CONF_TRIGGER): cv.TRIGGER_SCHEMA,
vol.Optional(CONF_CONDITION): _CONDITION_SCHEMA,
vol.Optional(CONF_VARIABLES): cv.SCRIPT_VARIABLES_SCHEMA,
vol.Optional(CONF_TRIGGER_VARIABLES): cv.SCRIPT_VARIABLES_SCHEMA,
vol.Required(CONF_ACTION): cv.SCRIPT_SCHEMA,
},
script.SCRIPT_MODE_SINGLE,

View File

@ -3,6 +3,7 @@ import logging
CONF_ACTION = "action"
CONF_TRIGGER = "trigger"
CONF_TRIGGER_VARIABLES = "trigger_variables"
DOMAIN = "automation"
CONF_DESCRIPTION = "description"

View File

@ -89,12 +89,14 @@ class TriggerInstance:
async def async_attach_trigger(self):
"""Attach MQTT trigger."""
mqtt_config = {
mqtt_trigger.CONF_PLATFORM: mqtt.DOMAIN,
mqtt_trigger.CONF_TOPIC: self.trigger.topic,
mqtt_trigger.CONF_ENCODING: DEFAULT_ENCODING,
mqtt_trigger.CONF_QOS: self.trigger.qos,
}
if self.trigger.payload:
mqtt_config[CONF_PAYLOAD] = self.trigger.payload
mqtt_config = mqtt_trigger.TRIGGER_SCHEMA(mqtt_config)
if self.remove:
self.remove()

View File

@ -1,11 +1,12 @@
"""Offer MQTT listening automation rules."""
import json
import logging
import voluptuous as vol
from homeassistant.const import CONF_PAYLOAD, CONF_PLATFORM
from homeassistant.core import HassJob, callback
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers import config_validation as cv, template
from .. import mqtt
@ -20,8 +21,8 @@ DEFAULT_QOS = 0
TRIGGER_SCHEMA = vol.Schema(
{
vol.Required(CONF_PLATFORM): mqtt.DOMAIN,
vol.Required(CONF_TOPIC): mqtt.util.valid_subscribe_topic,
vol.Optional(CONF_PAYLOAD): cv.string,
vol.Required(CONF_TOPIC): mqtt.util.valid_subscribe_topic_template,
vol.Optional(CONF_PAYLOAD): cv.template,
vol.Optional(CONF_ENCODING, default=DEFAULT_ENCODING): cv.string,
vol.Optional(CONF_QOS, default=DEFAULT_QOS): vol.All(
vol.Coerce(int), vol.In([0, 1, 2])
@ -29,6 +30,8 @@ TRIGGER_SCHEMA = vol.Schema(
}
)
_LOGGER = logging.getLogger(__name__)
async def async_attach_trigger(hass, config, action, automation_info):
"""Listen for state changes based on configuration."""
@ -37,6 +40,18 @@ async def async_attach_trigger(hass, config, action, automation_info):
encoding = config[CONF_ENCODING] or None
qos = config[CONF_QOS]
job = HassJob(action)
variables = None
if automation_info:
variables = automation_info.get("variables")
template.attach(hass, payload)
if payload:
payload = payload.async_render(variables, limited=True)
template.attach(hass, topic)
if isinstance(topic, template.Template):
topic = topic.async_render(variables, limited=True)
topic = mqtt.util.valid_subscribe_topic(topic)
@callback
def mqtt_automation_listener(mqttmsg):
@ -57,6 +72,10 @@ async def async_attach_trigger(hass, config, action, automation_info):
hass.async_run_hass_job(job, {"trigger": data})
_LOGGER.debug(
"Attaching MQTT trigger for topic: '%s', payload: '%s'", topic, payload
)
remove = await mqtt.async_subscribe(
hass, topic, mqtt_automation_listener, encoding=encoding, qos=qos
)

View File

@ -4,7 +4,7 @@ from typing import Any
import voluptuous as vol
from homeassistant.const import CONF_PAYLOAD
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers import config_validation as cv, template
from .const import (
ATTR_PAYLOAD,
@ -61,6 +61,16 @@ def valid_subscribe_topic(value: Any) -> str:
return value
def valid_subscribe_topic_template(value: Any) -> template.Template:
"""Validate either a jinja2 template or a valid MQTT subscription topic."""
tpl = template.Template(value)
if tpl.is_static:
valid_subscribe_topic(value)
return tpl
def valid_publish_topic(value: Any) -> str:
"""Validate that we can publish using this MQTT topic."""
value = valid_topic(value)

View File

@ -572,7 +572,7 @@ def dynamic_template(value: Optional[Any]) -> template_helper.Template:
if isinstance(value, (list, dict, template_helper.Template)):
raise vol.Invalid("template value should be a string")
if not template_helper.is_template_string(str(value)):
raise vol.Invalid("template value does not contain a dynmamic template")
raise vol.Invalid("template value does not contain a dynamic template")
template_value = template_helper.Template(str(value)) # type: ignore
try:

View File

@ -21,6 +21,7 @@ class ScriptVariables:
run_variables: Optional[Mapping[str, Any]],
*,
render_as_defaults: bool = True,
limited: bool = False,
) -> Dict[str, Any]:
"""Render script variables.
@ -55,7 +56,9 @@ class ScriptVariables:
if render_as_defaults and key in rendered_variables:
continue
rendered_variables[key] = template.render_complex(value, rendered_variables)
rendered_variables[key] = template.render_complex(
value, rendered_variables, limited
)
return rendered_variables

View File

@ -84,7 +84,9 @@ def attach(hass: HomeAssistantType, obj: Any) -> None:
obj.hass = hass
def render_complex(value: Any, variables: TemplateVarsType = None) -> Any:
def render_complex(
value: Any, variables: TemplateVarsType = None, limited: bool = False
) -> Any:
"""Recursive template creator helper function."""
if isinstance(value, list):
return [render_complex(item, variables) for item in value]
@ -94,7 +96,7 @@ def render_complex(value: Any, variables: TemplateVarsType = None) -> Any:
for key, item in value.items()
}
if isinstance(value, Template):
return value.async_render(variables)
return value.async_render(variables, limited=limited)
return value
@ -279,6 +281,7 @@ class Template:
"is_static",
"_compiled_code",
"_compiled",
"_limited",
)
def __init__(self, template, hass=None):
@ -291,10 +294,11 @@ class Template:
self._compiled: Optional[Template] = None
self.hass = hass
self.is_static = not is_template_string(template)
self._limited = None
@property
def _env(self) -> "TemplateEnvironment":
if self.hass is None:
if self.hass is None or self._limited:
return _NO_HASS_ENV
ret: Optional[TemplateEnvironment] = self.hass.data.get(_ENVIRONMENT)
if ret is None:
@ -315,9 +319,13 @@ class Template:
self,
variables: TemplateVarsType = None,
parse_result: bool = True,
limited: bool = False,
**kwargs: Any,
) -> Any:
"""Render given template."""
"""Render given template.
If limited is True, the template is not allowed to access any function or filter depending on hass or the state machine.
"""
if self.is_static:
if self.hass.config.legacy_templates or not parse_result:
return self.template
@ -325,7 +333,7 @@ class Template:
return run_callback_threadsafe(
self.hass.loop,
partial(self.async_render, variables, parse_result, **kwargs),
partial(self.async_render, variables, parse_result, limited, **kwargs),
).result()
@callback
@ -333,18 +341,21 @@ class Template:
self,
variables: TemplateVarsType = None,
parse_result: bool = True,
limited: bool = False,
**kwargs: Any,
) -> Any:
"""Render given template.
This method must be run in the event loop.
If limited is True, the template is not allowed to access any function or filter depending on hass or the state machine.
"""
if self.is_static:
if self.hass.config.legacy_templates or not parse_result:
return self.template
return self._parse_result(self.template)
compiled = self._compiled or self._ensure_compiled()
compiled = self._compiled or self._ensure_compiled(limited)
if variables is not None:
kwargs.update(variables)
@ -519,12 +530,16 @@ class Template:
)
return value if error_value is _SENTINEL else error_value
def _ensure_compiled(self) -> "Template":
def _ensure_compiled(self, limited: bool = False) -> "Template":
"""Bind a template to a specific hass instance."""
self.ensure_valid()
assert self.hass is not None, "hass variable not set on template"
assert (
self._limited is None or self._limited == limited
), "can't change between limited and non limited template"
self._limited = limited
env = self._env
self._compiled = cast(
@ -1352,6 +1367,31 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment):
self.globals["strptime"] = strptime
self.globals["urlencode"] = urlencode
if hass is None:
def unsupported(name):
def warn_unsupported(*args, **kwargs):
raise TemplateError(
f"Use of '{name}' is not supported in limited templates"
)
return warn_unsupported
hass_globals = [
"closest",
"distance",
"expand",
"is_state",
"is_state_attr",
"state_attr",
"states",
"utcnow",
"now",
]
hass_filters = ["closest", "expand"]
for glob in hass_globals:
self.globals[glob] = unsupported(glob)
for filt in hass_filters:
self.filters[filt] = unsupported(filt)
return
# We mark these as a context functions to ensure they get

View File

@ -8,6 +8,7 @@ import voluptuous as vol
from homeassistant.const import CONF_PLATFORM
from homeassistant.core import CALLBACK_TYPE, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.typing import ConfigType, HomeAssistantType
from homeassistant.loader import IntegrationNotFound, async_get_integration
@ -79,7 +80,9 @@ async def async_initialize_triggers(
removes = []
for result in attach_results:
if isinstance(result, Exception):
if isinstance(result, HomeAssistantError):
log_cb(logging.ERROR, f"Got error '{result}' when setting up triggers for")
elif isinstance(result, Exception):
log_cb(logging.ERROR, "Error setting up trigger", exc_info=result)
elif result is None:
log_cb(

View File

@ -1237,6 +1237,94 @@ async def test_automation_variables(hass, caplog):
assert len(calls) == 3
async def test_automation_trigger_variables(hass, caplog):
"""Test automation trigger variables."""
calls = async_mock_service(hass, "test", "automation")
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: [
{
"variables": {
"event_type": "{{ trigger.event.event_type }}",
},
"trigger_variables": {
"test_var": "defined_in_config",
},
"trigger": {"platform": "event", "event_type": "test_event"},
"action": {
"service": "test.automation",
"data": {
"value": "{{ test_var }}",
"event_type": "{{ event_type }}",
},
},
},
{
"variables": {
"event_type": "{{ trigger.event.event_type }}",
"test_var": "overridden_in_config",
},
"trigger_variables": {
"test_var": "defined_in_config",
},
"trigger": {"platform": "event", "event_type": "test_event_2"},
"action": {
"service": "test.automation",
"data": {
"value": "{{ test_var }}",
"event_type": "{{ event_type }}",
},
},
},
]
},
)
hass.bus.async_fire("test_event")
await hass.async_block_till_done()
assert len(calls) == 1
assert calls[0].data["value"] == "defined_in_config"
assert calls[0].data["event_type"] == "test_event"
hass.bus.async_fire("test_event_2")
await hass.async_block_till_done()
assert len(calls) == 2
assert calls[1].data["value"] == "overridden_in_config"
assert calls[1].data["event_type"] == "test_event_2"
assert "Error rendering variables" not in caplog.text
async def test_automation_bad_trigger_variables(hass, caplog):
"""Test automation trigger variables accessing hass is rejected."""
calls = async_mock_service(hass, "test", "automation")
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: [
{
"trigger_variables": {
"test_var": "{{ states('foo.bar') }}",
},
"trigger": {"platform": "event", "event_type": "test_event"},
"action": {
"service": "test.automation",
},
},
]
},
)
hass.bus.async_fire("test_event")
assert "Use of 'states' is not supported in limited templates" in caplog.text
await hass.async_block_till_done()
assert len(calls) == 0
async def test_blueprint_automation(hass, calls):
"""Test blueprint automation."""
assert await async_setup_component(

View File

@ -81,6 +81,58 @@ async def test_if_fires_on_topic_and_payload_match(hass, calls):
assert len(calls) == 1
async def test_if_fires_on_templated_topic_and_payload_match(hass, calls):
"""Test if message is fired on templated topic and payload match."""
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: {
"trigger": {
"platform": "mqtt",
"topic": "test-topic-{{ sqrt(16)|round }}",
"payload": '{{ "foo"|regex_replace("foo", "bar") }}',
},
"action": {"service": "test.automation"},
}
},
)
async_fire_mqtt_message(hass, "test-topic-", "foo")
await hass.async_block_till_done()
assert len(calls) == 0
async_fire_mqtt_message(hass, "test-topic-4", "foo")
await hass.async_block_till_done()
assert len(calls) == 0
async_fire_mqtt_message(hass, "test-topic-4", "bar")
await hass.async_block_till_done()
assert len(calls) == 1
async def test_non_allowed_templates(hass, calls, caplog):
"""Test non allowed function in template."""
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: {
"trigger": {
"platform": "mqtt",
"topic": "test-topic-{{ states() }}",
},
"action": {"service": "test.automation"},
}
},
)
assert (
"Got error 'TemplateError: str: Use of 'states' is not supported in limited templates' when setting up triggers"
in caplog.text
)
async def test_if_not_fires_on_topic_but_no_payload_match(hass, calls):
"""Test if message is not fired on topic but no payload."""
assert await async_setup_component(