1
mirror of https://github.com/home-assistant/core synced 2024-09-09 12:51:22 +02:00

Re-org device automations (#67064)

Co-authored-by: Franck Nijhof <git@frenck.dev>
This commit is contained in:
Paulus Schoutsen 2022-02-22 13:15:16 -08:00 committed by GitHub
parent 9950e543df
commit c2e62e4d9f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 188 additions and 127 deletions

View File

@ -7,14 +7,14 @@ from enum import Enum
from functools import wraps
import logging
from types import ModuleType
from typing import TYPE_CHECKING, Any, Literal, NamedTuple, Protocol, Union, overload
from typing import TYPE_CHECKING, Any, Literal, NamedTuple, Union, overload
import voluptuous as vol
import voluptuous_serialize
from homeassistant.components import websocket_api
from homeassistant.const import CONF_DEVICE_ID, CONF_DOMAIN, CONF_PLATFORM
from homeassistant.core import CALLBACK_TYPE, Context, HomeAssistant
from homeassistant.core import HomeAssistant
from homeassistant.helpers import (
config_validation as cv,
device_registry as dr,
@ -28,11 +28,16 @@ from homeassistant.requirements import async_get_integration_with_requirements
from .exceptions import DeviceNotFound, InvalidDeviceAutomationConfig
if TYPE_CHECKING:
from homeassistant.components.automation import (
AutomationActionType,
AutomationTriggerInfo,
)
from homeassistant.helpers import condition
from .action import DeviceAutomationActionProtocol
from .condition import DeviceAutomationConditionProtocol
from .trigger import DeviceAutomationTriggerProtocol
DeviceAutomationPlatformType = Union[
ModuleType,
DeviceAutomationTriggerProtocol,
DeviceAutomationConditionProtocol,
DeviceAutomationActionProtocol,
]
# mypy: allow-untyped-calls, allow-untyped-defs
@ -83,77 +88,6 @@ TYPES = {
}
class DeviceAutomationTriggerProtocol(Protocol):
"""Define the format of device_trigger modules.
Each module must define either TRIGGER_SCHEMA or async_validate_trigger_config.
"""
TRIGGER_SCHEMA: vol.Schema
async def async_validate_trigger_config(
self, hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
raise NotImplementedError
async def async_attach_trigger(
self,
hass: HomeAssistant,
config: ConfigType,
action: AutomationActionType,
automation_info: AutomationTriggerInfo,
) -> CALLBACK_TYPE:
"""Attach a trigger."""
raise NotImplementedError
class DeviceAutomationConditionProtocol(Protocol):
"""Define the format of device_condition modules.
Each module must define either CONDITION_SCHEMA or async_validate_condition_config.
"""
CONDITION_SCHEMA: vol.Schema
async def async_validate_condition_config(
self, hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
raise NotImplementedError
def async_condition_from_config(
self, hass: HomeAssistant, config: ConfigType
) -> condition.ConditionCheckerType:
"""Evaluate state based on configuration."""
raise NotImplementedError
class DeviceAutomationActionProtocol(Protocol):
"""Define the format of device_action modules.
Each module must define either ACTION_SCHEMA or async_validate_action_config.
"""
ACTION_SCHEMA: vol.Schema
async def async_validate_action_config(
self, hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
raise NotImplementedError
async def async_call_action_from_config(
self,
hass: HomeAssistant,
config: ConfigType,
variables: dict[str, Any],
context: Context | None,
) -> None:
"""Execute a device action."""
raise NotImplementedError
@bind_hass
async def async_get_device_automations(
hass: HomeAssistant,
@ -193,14 +127,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True
DeviceAutomationPlatformType = Union[
ModuleType,
DeviceAutomationTriggerProtocol,
DeviceAutomationConditionProtocol,
DeviceAutomationActionProtocol,
]
@overload
async def async_get_device_automation_platform( # noqa: D103
hass: HomeAssistant,
@ -231,13 +157,13 @@ async def async_get_device_automation_platform( # noqa: D103
@overload
async def async_get_device_automation_platform( # noqa: D103
hass: HomeAssistant, domain: str, automation_type: DeviceAutomationType | str
) -> DeviceAutomationPlatformType:
) -> "DeviceAutomationPlatformType":
...
async def async_get_device_automation_platform(
hass: HomeAssistant, domain: str, automation_type: DeviceAutomationType | str
) -> DeviceAutomationPlatformType:
) -> "DeviceAutomationPlatformType":
"""Load device automation platform for integration.
Throws InvalidDeviceAutomationConfig if the integration is not found or does not support device automation.

View File

@ -0,0 +1,68 @@
"""Device action validator."""
from __future__ import annotations
from typing import Any, Protocol, cast
import voluptuous as vol
from homeassistant.const import CONF_DOMAIN
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers.typing import ConfigType
from . import DeviceAutomationType, async_get_device_automation_platform
from .exceptions import InvalidDeviceAutomationConfig
class DeviceAutomationActionProtocol(Protocol):
"""Define the format of device_action modules.
Each module must define either ACTION_SCHEMA or async_validate_action_config.
"""
ACTION_SCHEMA: vol.Schema
async def async_validate_action_config(
self, hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
raise NotImplementedError
async def async_call_action_from_config(
self,
hass: HomeAssistant,
config: ConfigType,
variables: dict[str, Any],
context: Context | None,
) -> None:
"""Execute a device action."""
raise NotImplementedError
async def async_validate_action_config(
hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
try:
platform = await async_get_device_automation_platform(
hass, config[CONF_DOMAIN], DeviceAutomationType.ACTION
)
if hasattr(platform, "async_validate_action_config"):
return await platform.async_validate_action_config(hass, config)
return cast(ConfigType, platform.ACTION_SCHEMA(config))
except InvalidDeviceAutomationConfig as err:
raise vol.Invalid(str(err) or "Invalid action configuration") from err
async def async_call_action_from_config(
hass: HomeAssistant,
config: ConfigType,
variables: dict[str, Any],
context: Context | None,
) -> None:
"""Execute a device action."""
platform = await async_get_device_automation_platform(
hass,
config[CONF_DOMAIN],
DeviceAutomationType.ACTION,
)
await platform.async_call_action_from_config(hass, config, variables, context)

View File

@ -0,0 +1,64 @@
"""Validate device conditions."""
from __future__ import annotations
from typing import TYPE_CHECKING, Protocol, cast
import voluptuous as vol
from homeassistant.const import CONF_DOMAIN
from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.typing import ConfigType
from . import DeviceAutomationType, async_get_device_automation_platform
from .exceptions import InvalidDeviceAutomationConfig
if TYPE_CHECKING:
from homeassistant.helpers import condition
class DeviceAutomationConditionProtocol(Protocol):
"""Define the format of device_condition modules.
Each module must define either CONDITION_SCHEMA or async_validate_condition_config.
"""
CONDITION_SCHEMA: vol.Schema
async def async_validate_condition_config(
self, hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
raise NotImplementedError
def async_condition_from_config(
self, hass: HomeAssistant, config: ConfigType
) -> condition.ConditionCheckerType:
"""Evaluate state based on configuration."""
raise NotImplementedError
async def async_validate_condition_config(
hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate device condition config."""
try:
config = cv.DEVICE_CONDITION_SCHEMA(config)
platform = await async_get_device_automation_platform(
hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION
)
if hasattr(platform, "async_validate_condition_config"):
return await platform.async_validate_condition_config(hass, config)
return cast(ConfigType, platform.CONDITION_SCHEMA(config))
except InvalidDeviceAutomationConfig as err:
raise vol.Invalid(str(err) or "Invalid condition configuration") from err
async def async_condition_from_config(
hass: HomeAssistant, config: ConfigType
) -> condition.ConditionCheckerType:
"""Test a device condition."""
platform = await async_get_device_automation_platform(
hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION
)
return platform.async_condition_from_config(hass, config)

View File

@ -1,5 +1,5 @@
"""Offer device oriented automation."""
from typing import cast
from typing import Protocol, cast
import voluptuous as vol
@ -21,17 +21,41 @@ from .exceptions import InvalidDeviceAutomationConfig
TRIGGER_SCHEMA = DEVICE_TRIGGER_BASE_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)
class DeviceAutomationTriggerProtocol(Protocol):
"""Define the format of device_trigger modules.
Each module must define either TRIGGER_SCHEMA or async_validate_trigger_config.
"""
TRIGGER_SCHEMA: vol.Schema
async def async_validate_trigger_config(
self, hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
raise NotImplementedError
async def async_attach_trigger(
self,
hass: HomeAssistant,
config: ConfigType,
action: AutomationActionType,
automation_info: AutomationTriggerInfo,
) -> CALLBACK_TYPE:
"""Attach a trigger."""
raise NotImplementedError
async def async_validate_trigger_config(
hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
platform = await async_get_device_automation_platform(
hass, config[CONF_DOMAIN], DeviceAutomationType.TRIGGER
)
if not hasattr(platform, "async_validate_trigger_config"):
return cast(ConfigType, platform.TRIGGER_SCHEMA(config))
try:
platform = await async_get_device_automation_platform(
hass, config[CONF_DOMAIN], DeviceAutomationType.TRIGGER
)
if not hasattr(platform, "async_validate_trigger_config"):
return cast(ConfigType, platform.TRIGGER_SCHEMA(config))
return await platform.async_validate_trigger_config(hass, config)
except InvalidDeviceAutomationConfig as err:
raise vol.Invalid(str(err) or "Invalid trigger configuration") from err

View File

@ -13,10 +13,7 @@ import sys
from typing import Any, cast
from homeassistant.components import zone as zone_cmp
from homeassistant.components.device_automation import (
DeviceAutomationType,
async_get_device_automation_platform,
)
from homeassistant.components.device_automation import condition as device_condition
from homeassistant.components.sensor import SensorDeviceClass
from homeassistant.const import (
ATTR_DEVICE_CLASS,
@ -30,7 +27,6 @@ from homeassistant.const import (
CONF_BELOW,
CONF_CONDITION,
CONF_DEVICE_ID,
CONF_DOMAIN,
CONF_ENTITY_ID,
CONF_ID,
CONF_STATE,
@ -872,10 +868,8 @@ async def async_device_from_config(
hass: HomeAssistant, config: ConfigType
) -> ConditionCheckerType:
"""Test a device condition."""
platform = await async_get_device_automation_platform(
hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION
)
return trace_condition_function(platform.async_condition_from_config(hass, config))
checker = await device_condition.async_condition_from_config(hass, config)
return trace_condition_function(checker)
async def async_trigger_from_config(
@ -931,15 +925,10 @@ async def async_validate_condition_config(
sub_cond = await async_validate_condition_config(hass, sub_cond)
conditions.append(sub_cond)
config["conditions"] = conditions
return config
if condition == "device":
config = cv.DEVICE_CONDITION_SCHEMA(config)
platform = await async_get_device_automation_platform(
hass, config[CONF_DOMAIN], DeviceAutomationType.CONDITION
)
if hasattr(platform, "async_validate_condition_config"):
return await platform.async_validate_condition_config(hass, config)
return cast(ConfigType, platform.CONDITION_SCHEMA(config))
return await device_condition.async_validate_condition_config(hass, config)
if condition in ("numeric_state", "state"):
validator = cast(

View File

@ -15,7 +15,8 @@ import async_timeout
import voluptuous as vol
from homeassistant import exceptions
from homeassistant.components import device_automation, scene
from homeassistant.components import scene
from homeassistant.components.device_automation import action as device_action
from homeassistant.components.logger import LOGSEVERITY
from homeassistant.const import (
ATTR_AREA_ID,
@ -244,13 +245,7 @@ async def async_validate_action_config(
pass
elif action_type == cv.SCRIPT_ACTION_DEVICE_AUTOMATION:
platform = await device_automation.async_get_device_automation_platform(
hass, config[CONF_DOMAIN], device_automation.DeviceAutomationType.ACTION
)
if hasattr(platform, "async_validate_action_config"):
config = await platform.async_validate_action_config(hass, config)
else:
config = platform.ACTION_SCHEMA(config)
config = await device_action.async_validate_action_config(hass, config)
elif action_type == cv.SCRIPT_ACTION_CHECK_CONDITION:
config = await condition.async_validate_condition_config(hass, config)
@ -580,12 +575,7 @@ class _ScriptRun:
async def _async_device_step(self):
"""Perform the device automation specified in the action."""
self._step_log("device automation")
platform = await device_automation.async_get_device_automation_platform(
self._hass,
self._action[CONF_DOMAIN],
device_automation.DeviceAutomationType.ACTION,
)
await platform.async_call_action_from_config(
await device_action.async_call_action_from_config(
self._hass, self._action, self._variables, self._context
)

View File

@ -2977,7 +2977,7 @@ async def test_platform_async_validate_condition_config(hass):
config = {CONF_DEVICE_ID: "test", CONF_DOMAIN: "test", CONF_CONDITION: "device"}
platform = AsyncMock()
with patch(
"homeassistant.helpers.condition.async_get_device_automation_platform",
"homeassistant.components.device_automation.condition.async_get_device_automation_platform",
return_value=platform,
):
platform.async_validate_condition_config.return_value = config

View File

@ -3721,7 +3721,7 @@ async def test_platform_async_validate_action_config(hass):
config = {CONF_DEVICE_ID: "test", CONF_DOMAIN: "test"}
platform = AsyncMock()
with patch(
"homeassistant.helpers.script.device_automation.async_get_device_automation_platform",
"homeassistant.components.device_automation.action.async_get_device_automation_platform",
return_value=platform,
):
platform.async_validate_action_config.return_value = config