1
mirror of https://github.com/home-assistant/core synced 2024-08-15 18:25:44 +02:00

Resolve traceback error when using variables in template triggers (#77287)

Co-authored-by: Erik <erik@montnemery.com>
This commit is contained in:
ehendrix23 2022-09-29 11:26:28 -06:00 committed by GitHub
parent ee32e0eb3f
commit ba6a81c565
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 112 additions and 15 deletions

View File

@ -2,10 +2,10 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable
from collections.abc import Callable, Coroutine
import functools
import logging
from typing import TYPE_CHECKING, Any, Protocol, TypedDict
from typing import TYPE_CHECKING, Any, Protocol, TypedDict, cast
import voluptuous as vol
@ -16,7 +16,13 @@ from homeassistant.const import (
CONF_PLATFORM,
CONF_VARIABLES,
)
from homeassistant.core import CALLBACK_TYPE, Context, HomeAssistant, callback
from homeassistant.core import (
CALLBACK_TYPE,
Context,
HomeAssistant,
callback,
is_callback,
)
from homeassistant.exceptions import HomeAssistantError
from homeassistant.loader import IntegrationNotFound, async_get_integration
@ -101,20 +107,51 @@ async def async_validate_trigger_config(
def _trigger_action_wrapper(
hass: HomeAssistant, action: Callable, conf: ConfigType
) -> Callable:
"""Wrap trigger action with extra vars if configured."""
"""Wrap trigger action with extra vars if configured.
If action is a coroutine function, a coroutine function will be returned.
If action is a callback, a callback will be returned.
"""
if CONF_VARIABLES not in conf:
return action
@functools.wraps(action)
async def with_vars(
run_variables: dict[str, Any], context: Context | None = None
) -> None:
"""Wrap action with extra vars."""
trigger_variables = conf[CONF_VARIABLES]
run_variables.update(trigger_variables.async_render(hass, run_variables))
await action(run_variables, context)
# Check for partials to properly determine if coroutine function
check_func = action
while isinstance(check_func, functools.partial):
check_func = check_func.func
return with_vars
wrapper_func: Callable[..., None] | Callable[..., Coroutine[Any, Any, None]]
if asyncio.iscoroutinefunction(check_func):
async_action = cast(Callable[..., Coroutine[Any, Any, None]], action)
@functools.wraps(async_action)
async def async_with_vars(
run_variables: dict[str, Any], context: Context | None = None
) -> None:
"""Wrap action with extra vars."""
trigger_variables = conf[CONF_VARIABLES]
run_variables.update(trigger_variables.async_render(hass, run_variables))
await action(run_variables, context)
wrapper_func = async_with_vars
else:
@functools.wraps(action)
async def with_vars(
run_variables: dict[str, Any], context: Context | None = None
) -> None:
"""Wrap action with extra vars."""
trigger_variables = conf[CONF_VARIABLES]
run_variables.update(trigger_variables.async_render(hass, run_variables))
action(run_variables, context)
if is_callback(check_func):
with_vars = callback(with_vars)
wrapper_func = with_vars
return wrapper_func
async def async_initialize_triggers(

View File

@ -1,12 +1,13 @@
"""The tests for the trigger helper."""
from unittest.mock import MagicMock, call, patch
from unittest.mock import ANY, MagicMock, call, patch
import pytest
import voluptuous as vol
from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.core import HomeAssistant, ServiceCall, callback
from homeassistant.helpers.trigger import (
_async_get_trigger_platform,
async_initialize_triggers,
async_validate_trigger_config,
)
from homeassistant.setup import async_setup_component
@ -137,3 +138,62 @@ async def test_trigger_alias(
"Automation trigger 'My event' triggered by event 'trigger_event'"
in caplog.text
)
async def test_async_initialize_triggers(
hass: HomeAssistant, calls: list[ServiceCall], caplog: pytest.LogCaptureFixture
) -> None:
"""Test async_initialize_triggers with different action types."""
log_cb = MagicMock()
action_calls = []
trigger_config = await async_validate_trigger_config(
hass,
[
{
"platform": "event",
"event_type": ["trigger_event"],
"variables": {
"name": "Paulus",
"via_event": "{{ trigger.event.event_type }}",
},
}
],
)
async def async_action(*args):
action_calls.append([*args])
@callback
def cb_action(*args):
action_calls.append([*args])
def non_cb_action(*args):
action_calls.append([*args])
for action in (async_action, cb_action, non_cb_action):
action_calls = []
unsub = await async_initialize_triggers(
hass,
trigger_config,
action,
"test",
"",
log_cb,
)
await hass.async_block_till_done()
hass.bus.async_fire("trigger_event")
await hass.async_block_till_done()
await hass.async_block_till_done()
assert len(action_calls) == 1
assert action_calls[0][0]["name"] == "Paulus"
assert action_calls[0][0]["via_event"] == "trigger_event"
log_cb.assert_called_once_with(ANY, "Initialized trigger")
log_cb.reset_mock()
unsub()