1
mirror of https://github.com/home-assistant/core synced 2024-07-18 12:02:20 +02:00

Refactor template entity to allow reuse (#72753)

* Refactor template entity to allow reuse

* Fix schema and default name

* Add tests

* Update requirements

* Improve test

* Tweak TemplateSensor initializer

* Drop attributes and availability from TemplateEntity

* Use rest sensor for proof of concept

* Revert changes in SNMP sensor

* Don't set _attr_should_poll in mixin class

* Update requirements
This commit is contained in:
Erik Montnemery 2022-06-08 15:55:49 +02:00 committed by GitHub
parent 79096864eb
commit 5987266e56
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 580 additions and 472 deletions

View File

@ -10,29 +10,21 @@ from homeassistant.helpers.update_coordinator import DataUpdateCoordinator
from .data import RestData
class RestEntity(Entity):
class BaseRestEntity(Entity):
"""A class for entities using DataUpdateCoordinator or rest data directly."""
def __init__(
self,
coordinator: DataUpdateCoordinator[Any],
rest: RestData,
name,
resource_template,
force_update,
) -> None:
"""Create the entity that may have a coordinator."""
self.coordinator = coordinator
self.rest = rest
self._name = name
self._resource_template = resource_template
self._force_update = force_update
super().__init__()
@property
def name(self):
"""Return the name of the sensor."""
return self._name
@property
def force_update(self):
@ -41,7 +33,7 @@ class RestEntity(Entity):
@property
def should_poll(self) -> bool:
"""Poll only if we do noty have a coordinator."""
"""Poll only if we do not have a coordinator."""
return not self.coordinator
@property
@ -80,3 +72,24 @@ class RestEntity(Entity):
@abstractmethod
def _update_from_rest_data(self):
"""Update state from the rest data."""
class RestEntity(BaseRestEntity):
"""A class for entities using DataUpdateCoordinator or rest data directly."""
def __init__(
self,
coordinator: DataUpdateCoordinator[Any],
rest: RestData,
name,
resource_template,
force_update,
) -> None:
"""Create the entity that may have a coordinator."""
self._name = name
super().__init__(coordinator, rest, resource_template, force_update)
@property
def name(self):
"""Return the name of the sensor."""
return self._name

View File

@ -6,12 +6,7 @@ from homeassistant.components.binary_sensor import (
DEVICE_CLASSES_SCHEMA as BINARY_SENSOR_DEVICE_CLASSES_SCHEMA,
DOMAIN as BINARY_SENSOR_DOMAIN,
)
from homeassistant.components.sensor import (
DEVICE_CLASSES_SCHEMA as SENSOR_DEVICE_CLASSES_SCHEMA,
DOMAIN as SENSOR_DOMAIN,
STATE_CLASSES_SCHEMA,
)
from homeassistant.components.sensor.const import CONF_STATE_CLASS
from homeassistant.components.sensor import DOMAIN as SENSOR_DOMAIN
from homeassistant.const import (
CONF_AUTHENTICATION,
CONF_DEVICE_CLASS,
@ -26,7 +21,6 @@ from homeassistant.const import (
CONF_RESOURCE_TEMPLATE,
CONF_SCAN_INTERVAL,
CONF_TIMEOUT,
CONF_UNIT_OF_MEASUREMENT,
CONF_USERNAME,
CONF_VALUE_TEMPLATE,
CONF_VERIFY_SSL,
@ -34,6 +28,7 @@ from homeassistant.const import (
HTTP_DIGEST_AUTHENTICATION,
)
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.template_entity import TEMPLATE_SENSOR_BASE_SCHEMA
from .const import (
CONF_JSON_ATTRS,
@ -41,7 +36,6 @@ from .const import (
DEFAULT_BINARY_SENSOR_NAME,
DEFAULT_FORCE_UPDATE,
DEFAULT_METHOD,
DEFAULT_SENSOR_NAME,
DEFAULT_VERIFY_SSL,
DOMAIN,
METHODS,
@ -65,10 +59,7 @@ RESOURCE_SCHEMA = {
}
SENSOR_SCHEMA = {
vol.Optional(CONF_NAME, default=DEFAULT_SENSOR_NAME): cv.string,
vol.Optional(CONF_UNIT_OF_MEASUREMENT): cv.string,
vol.Optional(CONF_DEVICE_CLASS): SENSOR_DEVICE_CLASSES_SCHEMA,
vol.Optional(CONF_STATE_CLASS): STATE_CLASSES_SCHEMA,
**TEMPLATE_SENSOR_BASE_SCHEMA.schema,
vol.Optional(CONF_JSON_ATTRS, default=[]): cv.ensure_list_csv,
vol.Optional(CONF_JSON_ATTRS_PATH): cv.string,
vol.Optional(CONF_VALUE_TEMPLATE): cv.template,

View File

@ -10,31 +10,28 @@ import voluptuous as vol
import xmltodict
from homeassistant.components.sensor import (
CONF_STATE_CLASS,
DOMAIN as SENSOR_DOMAIN,
PLATFORM_SCHEMA,
SensorDeviceClass,
SensorEntity,
)
from homeassistant.components.sensor.helpers import async_parse_date_datetime
from homeassistant.const import (
CONF_DEVICE_CLASS,
CONF_FORCE_UPDATE,
CONF_NAME,
CONF_RESOURCE,
CONF_RESOURCE_TEMPLATE,
CONF_UNIT_OF_MEASUREMENT,
CONF_UNIQUE_ID,
CONF_VALUE_TEMPLATE,
)
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import PlatformNotReady
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.template_entity import TemplateSensor
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import async_get_config_and_coordinator, create_rest_data_from_config
from .const import CONF_JSON_ATTRS, CONF_JSON_ATTRS_PATH
from .entity import RestEntity
from .const import CONF_JSON_ATTRS, CONF_JSON_ATTRS_PATH, DEFAULT_SENSOR_NAME
from .entity import BaseRestEntity
from .schema import RESOURCE_SCHEMA, SENSOR_SCHEMA
_LOGGER = logging.getLogger(__name__)
@ -70,67 +67,54 @@ async def async_setup_platform(
raise PlatformNotReady from rest.last_exception
raise PlatformNotReady
name = conf.get(CONF_NAME)
unit = conf.get(CONF_UNIT_OF_MEASUREMENT)
device_class = conf.get(CONF_DEVICE_CLASS)
state_class = conf.get(CONF_STATE_CLASS)
json_attrs = conf.get(CONF_JSON_ATTRS)
json_attrs_path = conf.get(CONF_JSON_ATTRS_PATH)
value_template = conf.get(CONF_VALUE_TEMPLATE)
force_update = conf.get(CONF_FORCE_UPDATE)
resource_template = conf.get(CONF_RESOURCE_TEMPLATE)
if value_template is not None:
value_template.hass = hass
unique_id = conf.get(CONF_UNIQUE_ID)
async_add_entities(
[
RestSensor(
hass,
coordinator,
rest,
name,
unit,
device_class,
state_class,
value_template,
json_attrs,
force_update,
resource_template,
json_attrs_path,
conf,
unique_id,
)
],
)
class RestSensor(RestEntity, SensorEntity):
class RestSensor(BaseRestEntity, TemplateSensor):
"""Implementation of a REST sensor."""
def __init__(
self,
hass,
coordinator,
rest,
name,
unit_of_measurement,
device_class,
state_class,
value_template,
json_attrs,
force_update,
resource_template,
json_attrs_path,
config,
unique_id,
):
"""Initialize the REST sensor."""
super().__init__(coordinator, rest, name, resource_template, force_update)
BaseRestEntity.__init__(
self,
coordinator,
rest,
config.get(CONF_RESOURCE_TEMPLATE),
config.get(CONF_FORCE_UPDATE),
)
TemplateSensor.__init__(
self,
hass,
config=config,
fallback_name=DEFAULT_SENSOR_NAME,
unique_id=unique_id,
)
self._state = None
self._unit_of_measurement = unit_of_measurement
self._value_template = value_template
self._json_attrs = json_attrs
self._value_template = config.get(CONF_VALUE_TEMPLATE)
if (value_template := self._value_template) is not None:
value_template.hass = hass
self._json_attrs = config.get(CONF_JSON_ATTRS)
self._attributes = None
self._json_attrs_path = json_attrs_path
self._attr_native_unit_of_measurement = self._unit_of_measurement
self._attr_device_class = device_class
self._attr_state_class = state_class
self._json_attrs_path = config.get(CONF_JSON_ATTRS_PATH)
@property
def native_value(self):

View File

@ -125,6 +125,8 @@ async def async_setup_platform(
class AlarmControlPanelTemplate(TemplateEntity, AlarmControlPanelEntity):
"""Representation of a templated Alarm Control Panel."""
_attr_should_poll = False
def __init__(
self,
hass,

View File

@ -195,6 +195,8 @@ async def async_setup_platform(
class BinarySensorTemplate(TemplateEntity, BinarySensorEntity, RestoreEntity):
"""A virtual binary sensor that triggers from another sensor."""
_attr_should_poll = False
def __init__(
self,
hass: HomeAssistant,

View File

@ -78,6 +78,8 @@ async def async_setup_platform(
class TemplateButtonEntity(TemplateEntity, ButtonEntity):
"""Representation of a template button."""
_attr_should_poll = False
def __init__(
self,
hass: HomeAssistant,

View File

@ -133,6 +133,8 @@ async def async_setup_platform(
class CoverTemplate(TemplateEntity, CoverEntity):
"""Representation of a Template cover."""
_attr_should_poll = False
def __init__(
self,
hass,

View File

@ -125,6 +125,8 @@ async def async_setup_platform(
class TemplateFan(TemplateEntity, FanEntity):
"""A template fan component."""
_attr_should_poll = False
def __init__(
self,
hass,

View File

@ -136,6 +136,8 @@ async def async_setup_platform(
class LightTemplate(TemplateEntity, LightEntity):
"""Representation of a templated Light, including dimmable."""
_attr_should_poll = False
def __init__(
self,
hass,

View File

@ -70,6 +70,8 @@ async def async_setup_platform(
class TemplateLock(TemplateEntity, LockEntity):
"""Representation of a template lock."""
_attr_should_poll = False
def __init__(
self,
hass,

View File

@ -100,6 +100,8 @@ async def async_setup_platform(
class TemplateNumber(TemplateEntity, NumberEntity):
"""Representation of a template number."""
_attr_should_poll = False
def __init__(
self,
hass: HomeAssistant,

View File

@ -94,6 +94,8 @@ async def async_setup_platform(
class TemplateSelect(TemplateEntity, SelectEntity):
"""Representation of a template select."""
_attr_should_poll = False
def __init__(
self,
hass: HomeAssistant,

View File

@ -12,10 +12,8 @@ from homeassistant.components.sensor import (
DOMAIN as SENSOR_DOMAIN,
ENTITY_ID_FORMAT,
PLATFORM_SCHEMA,
STATE_CLASSES_SCHEMA,
RestoreSensor,
SensorDeviceClass,
SensorEntity,
)
from homeassistant.components.sensor.helpers import async_parse_date_datetime
from homeassistant.const import (
@ -39,6 +37,10 @@ from homeassistant.exceptions import TemplateError
from homeassistant.helpers import config_validation as cv, template
from homeassistant.helpers.entity import async_generate_entity_id
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.template_entity import (
TEMPLATE_SENSOR_BASE_SCHEMA,
TemplateSensor,
)
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from .const import (
@ -49,7 +51,6 @@ from .const import (
)
from .template_entity import (
TEMPLATE_ENTITY_COMMON_SCHEMA,
TemplateEntity,
rewrite_common_legacy_to_modern_conf,
)
from .trigger_entity import TriggerEntity
@ -61,16 +62,15 @@ LEGACY_FIELDS = {
}
SENSOR_SCHEMA = vol.Schema(
{
vol.Optional(CONF_DEVICE_CLASS): DEVICE_CLASSES_SCHEMA,
vol.Optional(CONF_NAME): cv.template,
vol.Optional(CONF_STATE_CLASS): STATE_CLASSES_SCHEMA,
vol.Required(CONF_STATE): cv.template,
vol.Optional(CONF_UNIQUE_ID): cv.string,
vol.Optional(CONF_UNIT_OF_MEASUREMENT): cv.string,
}
).extend(TEMPLATE_ENTITY_COMMON_SCHEMA.schema)
SENSOR_SCHEMA = (
vol.Schema(
{
vol.Required(CONF_STATE): cv.template,
}
)
.extend(TEMPLATE_SENSOR_BASE_SCHEMA.schema)
.extend(TEMPLATE_ENTITY_COMMON_SCHEMA.schema)
)
LEGACY_SENSOR_SCHEMA = vol.All(
@ -192,9 +192,11 @@ async def async_setup_platform(
)
class SensorTemplate(TemplateEntity, SensorEntity):
class SensorTemplate(TemplateSensor):
"""Representation of a Template Sensor."""
_attr_should_poll = False
def __init__(
self,
hass: HomeAssistant,
@ -202,17 +204,13 @@ class SensorTemplate(TemplateEntity, SensorEntity):
unique_id: str | None,
) -> None:
"""Initialize the sensor."""
super().__init__(hass, config=config, unique_id=unique_id)
super().__init__(hass, config=config, fallback_name=None, unique_id=unique_id)
self._template = config.get(CONF_STATE)
if (object_id := config.get(CONF_OBJECT_ID)) is not None:
self.entity_id = async_generate_entity_id(
ENTITY_ID_FORMAT, object_id, hass=hass
)
self._attr_native_unit_of_measurement = config.get(CONF_UNIT_OF_MEASUREMENT)
self._template = config.get(CONF_STATE)
self._attr_device_class = config.get(CONF_DEVICE_CLASS)
self._attr_state_class = config.get(CONF_STATE_CLASS)
async def async_added_to_hass(self):
"""Register callbacks."""
self.add_template_attribute(

View File

@ -90,6 +90,8 @@ async def async_setup_platform(
class SwitchTemplate(TemplateEntity, SwitchEntity, RestoreEntity):
"""Representation of a Template switch."""
_attr_should_poll = False
def __init__(
self,
hass,
@ -149,11 +151,6 @@ class SwitchTemplate(TemplateEntity, SwitchEntity, RestoreEntity):
"""Return true if device is on."""
return self._state
@property
def should_poll(self):
"""Return the polling state."""
return False
async def async_turn_on(self, **kwargs):
"""Fire the on action."""
await self.async_run_script(self._on_script, context=self._context)

View File

@ -1,38 +1,23 @@
"""TemplateEntity utility class."""
from __future__ import annotations
from collections.abc import Callable
import contextlib
import itertools
import logging
from typing import Any
import voluptuous as vol
from homeassistant.const import (
ATTR_ENTITY_ID,
CONF_ENTITY_PICTURE_TEMPLATE,
CONF_FRIENDLY_NAME,
CONF_ICON,
CONF_ICON_TEMPLATE,
CONF_NAME,
EVENT_HOMEASSISTANT_START,
STATE_UNKNOWN,
)
from homeassistant.core import Context, CoreState, Event, State, callback
from homeassistant.exceptions import TemplateError
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.event import (
TrackTemplate,
TrackTemplateResult,
async_track_template_result,
)
from homeassistant.helpers.script import Script, _VarsType
from homeassistant.helpers.template import (
Template,
TemplateStateFromEntityId,
result_as_boolean,
from homeassistant.helpers.template import Template
from homeassistant.helpers.template_entity import ( # noqa: F401 pylint: disable=unused-import
TEMPLATE_ENTITY_BASE_SCHEMA,
TemplateEntity,
)
from .const import (
@ -43,9 +28,6 @@ from .const import (
CONF_PICTURE,
)
_LOGGER = logging.getLogger(__name__)
TEMPLATE_ENTITY_AVAILABILITY_SCHEMA = vol.Schema(
{
vol.Optional(CONF_AVAILABILITY): cv.template,
@ -62,10 +44,8 @@ TEMPLATE_ENTITY_COMMON_SCHEMA = vol.Schema(
{
vol.Optional(CONF_ATTRIBUTES): vol.Schema({cv.string: cv.template}),
vol.Optional(CONF_AVAILABILITY): cv.template,
vol.Optional(CONF_ICON): cv.template,
vol.Optional(CONF_PICTURE): cv.template,
}
)
).extend(TEMPLATE_ENTITY_BASE_SCHEMA.schema)
TEMPLATE_ENTITY_ATTRIBUTES_SCHEMA_LEGACY = vol.Schema(
{
@ -121,356 +101,3 @@ def rewrite_common_legacy_to_modern_conf(
entity_cfg[CONF_NAME] = Template(entity_cfg[CONF_NAME])
return entity_cfg
class _TemplateAttribute:
"""Attribute value linked to template result."""
def __init__(
self,
entity: Entity,
attribute: str,
template: Template,
validator: Callable[[Any], Any] = None,
on_update: Callable[[Any], None] | None = None,
none_on_template_error: bool | None = False,
) -> None:
"""Template attribute."""
self._entity = entity
self._attribute = attribute
self.template = template
self.validator = validator
self.on_update = on_update
self.async_update = None
self.none_on_template_error = none_on_template_error
@callback
def async_setup(self):
"""Config update path for the attribute."""
if self.on_update:
return
if not hasattr(self._entity, self._attribute):
raise AttributeError(f"Attribute '{self._attribute}' does not exist.")
self.on_update = self._default_update
@callback
def _default_update(self, result):
attr_result = None if isinstance(result, TemplateError) else result
setattr(self._entity, self._attribute, attr_result)
@callback
def handle_result(
self,
event: Event | None,
template: Template,
last_result: str | None | TemplateError,
result: str | TemplateError,
) -> None:
"""Handle a template result event callback."""
if isinstance(result, TemplateError):
_LOGGER.error(
"TemplateError('%s') "
"while processing template '%s' "
"for attribute '%s' in entity '%s'",
result,
self.template,
self._attribute,
self._entity.entity_id,
)
if self.none_on_template_error:
self._default_update(result)
else:
assert self.on_update
self.on_update(result)
return
if not self.validator:
assert self.on_update
self.on_update(result)
return
try:
validated = self.validator(result)
except vol.Invalid as ex:
_LOGGER.error(
"Error validating template result '%s' "
"from template '%s' "
"for attribute '%s' in entity %s "
"validation message '%s'",
result,
self.template,
self._attribute,
self._entity.entity_id,
ex.msg,
)
assert self.on_update
self.on_update(None)
return
assert self.on_update
self.on_update(validated)
return
class TemplateEntity(Entity):
"""Entity that uses templates to calculate attributes."""
_attr_available = True
_attr_entity_picture = None
_attr_icon = None
_attr_should_poll = False
def __init__(
self,
hass,
*,
availability_template=None,
icon_template=None,
entity_picture_template=None,
attribute_templates=None,
config=None,
fallback_name=None,
unique_id=None,
):
"""Template Entity."""
self._template_attrs = {}
self._async_update = None
self._attr_extra_state_attributes = {}
self._self_ref_update_count = 0
self._attr_unique_id = unique_id
if config is None:
self._attribute_templates = attribute_templates
self._availability_template = availability_template
self._icon_template = icon_template
self._entity_picture_template = entity_picture_template
self._friendly_name_template = None
else:
self._attribute_templates = config.get(CONF_ATTRIBUTES)
self._availability_template = config.get(CONF_AVAILABILITY)
self._icon_template = config.get(CONF_ICON)
self._entity_picture_template = config.get(CONF_PICTURE)
self._friendly_name_template = config.get(CONF_NAME)
class DummyState(State):
"""None-state for template entities not yet added to the state machine."""
def __init__(self) -> None:
"""Initialize a new state."""
super().__init__("unknown.unknown", STATE_UNKNOWN)
self.entity_id = None # type: ignore[assignment]
@property
def name(self) -> str:
"""Name of this state."""
return "<None>"
variables = {"this": DummyState()}
# Try to render the name as it can influence the entity ID
self._attr_name = fallback_name
if self._friendly_name_template:
self._friendly_name_template.hass = hass
with contextlib.suppress(TemplateError):
self._attr_name = self._friendly_name_template.async_render(
variables=variables, parse_result=False
)
# Templates will not render while the entity is unavailable, try to render the
# icon and picture templates.
if self._entity_picture_template:
self._entity_picture_template.hass = hass
with contextlib.suppress(TemplateError):
self._attr_entity_picture = self._entity_picture_template.async_render(
variables=variables, parse_result=False
)
if self._icon_template:
self._icon_template.hass = hass
with contextlib.suppress(TemplateError):
self._attr_icon = self._icon_template.async_render(
variables=variables, parse_result=False
)
@callback
def _update_available(self, result):
if isinstance(result, TemplateError):
self._attr_available = True
return
self._attr_available = result_as_boolean(result)
@callback
def _update_state(self, result):
if self._availability_template:
return
self._attr_available = not isinstance(result, TemplateError)
@callback
def _add_attribute_template(self, attribute_key, attribute_template):
"""Create a template tracker for the attribute."""
def _update_attribute(result):
attr_result = None if isinstance(result, TemplateError) else result
self._attr_extra_state_attributes[attribute_key] = attr_result
self.add_template_attribute(
attribute_key, attribute_template, None, _update_attribute
)
def add_template_attribute(
self,
attribute: str,
template: Template,
validator: Callable[[Any], Any] = None,
on_update: Callable[[Any], None] | None = None,
none_on_template_error: bool = False,
) -> None:
"""
Call in the constructor to add a template linked to a attribute.
Parameters
----------
attribute
The name of the attribute to link to. This attribute must exist
unless a custom on_update method is supplied.
template
The template to calculate.
validator
Validator function to parse the result and ensure it's valid.
on_update
Called to store the template result rather than storing it
the supplied attribute. Passed the result of the validator, or None
if the template or validator resulted in an error.
"""
assert self.hass is not None, "hass cannot be None"
template.hass = self.hass
template_attribute = _TemplateAttribute(
self, attribute, template, validator, on_update, none_on_template_error
)
self._template_attrs.setdefault(template, [])
self._template_attrs[template].append(template_attribute)
@callback
def _handle_results(
self,
event: Event | None,
updates: list[TrackTemplateResult],
) -> None:
"""Call back the results to the attributes."""
if event:
self.async_set_context(event.context)
entity_id = event and event.data.get(ATTR_ENTITY_ID)
if entity_id and entity_id == self.entity_id:
self._self_ref_update_count += 1
else:
self._self_ref_update_count = 0
if self._self_ref_update_count > len(self._template_attrs):
for update in updates:
_LOGGER.warning(
"Template loop detected while processing event: %s, skipping template render for Template[%s]",
event,
update.template.template,
)
return
for update in updates:
for attr in self._template_attrs[update.template]:
attr.handle_result(
event, update.template, update.last_result, update.result
)
self.async_write_ha_state()
async def _async_template_startup(self, *_) -> None:
template_var_tups: list[TrackTemplate] = []
has_availability_template = False
variables = {"this": TemplateStateFromEntityId(self.hass, self.entity_id)}
for template, attributes in self._template_attrs.items():
template_var_tup = TrackTemplate(template, variables)
is_availability_template = False
for attribute in attributes:
# pylint: disable-next=protected-access
if attribute._attribute == "_attr_available":
has_availability_template = True
is_availability_template = True
attribute.async_setup()
# Insert the availability template first in the list
if is_availability_template:
template_var_tups.insert(0, template_var_tup)
else:
template_var_tups.append(template_var_tup)
result_info = async_track_template_result(
self.hass,
template_var_tups,
self._handle_results,
has_super_template=has_availability_template,
)
self.async_on_remove(result_info.async_remove)
self._async_update = result_info.async_refresh
result_info.async_refresh()
async def async_added_to_hass(self) -> None:
"""Run when entity about to be added to hass."""
if self._availability_template is not None:
self.add_template_attribute(
"_attr_available",
self._availability_template,
None,
self._update_available,
)
if self._attribute_templates is not None:
for key, value in self._attribute_templates.items():
self._add_attribute_template(key, value)
if self._icon_template is not None:
self.add_template_attribute(
"_attr_icon", self._icon_template, vol.Or(cv.whitespace, cv.icon)
)
if self._entity_picture_template is not None:
self.add_template_attribute(
"_attr_entity_picture", self._entity_picture_template
)
if (
self._friendly_name_template is not None
and not self._friendly_name_template.is_static
):
self.add_template_attribute("_attr_name", self._friendly_name_template)
if self.hass.state == CoreState.running:
await self._async_template_startup()
return
self.hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_START, self._async_template_startup
)
async def async_update(self) -> None:
"""Call for forced update."""
self._async_update()
async def async_run_script(
self,
script: Script,
*,
run_variables: _VarsType | None = None,
context: Context | None = None,
) -> None:
"""Run an action script."""
if run_variables is None:
run_variables = {}
return await script.async_run(
run_variables={
"this": TemplateStateFromEntityId(self.hass, self.entity_id),
**run_variables,
},
context=context,
)

View File

@ -126,6 +126,8 @@ async def async_setup_platform(
class TemplateVacuum(TemplateEntity, StateVacuumEntity):
"""A template vacuum component."""
_attr_should_poll = False
def __init__(
self,
hass,

View File

@ -105,6 +105,8 @@ async def async_setup_platform(
class WeatherTemplate(TemplateEntity, WeatherEntity):
"""Representation of a weather condition."""
_attr_should_poll = False
def __init__(
self,
hass,

View File

@ -0,0 +1,434 @@
"""TemplateEntity utility class."""
from __future__ import annotations
from collections.abc import Callable
import contextlib
import logging
from typing import Any
import voluptuous as vol
from homeassistant.components.sensor import (
CONF_STATE_CLASS,
DEVICE_CLASSES_SCHEMA,
STATE_CLASSES_SCHEMA,
SensorEntity,
)
from homeassistant.const import (
ATTR_ENTITY_ID,
CONF_DEVICE_CLASS,
CONF_ICON,
CONF_NAME,
CONF_UNIQUE_ID,
CONF_UNIT_OF_MEASUREMENT,
EVENT_HOMEASSISTANT_START,
STATE_UNKNOWN,
)
from homeassistant.core import Context, CoreState, Event, HomeAssistant, State, callback
from homeassistant.exceptions import TemplateError
from . import config_validation as cv
from .entity import Entity
from .event import TrackTemplate, TrackTemplateResult, async_track_template_result
from .script import Script, _VarsType
from .template import Template, TemplateStateFromEntityId, result_as_boolean
from .typing import ConfigType
_LOGGER = logging.getLogger(__name__)
CONF_AVAILABILITY = "availability"
CONF_ATTRIBUTES = "attributes"
CONF_PICTURE = "picture"
TEMPLATE_ENTITY_BASE_SCHEMA = vol.Schema(
{
vol.Optional(CONF_ICON): cv.template,
vol.Optional(CONF_PICTURE): cv.template,
}
)
TEMPLATE_SENSOR_BASE_SCHEMA = vol.Schema(
{
vol.Optional(CONF_DEVICE_CLASS): DEVICE_CLASSES_SCHEMA,
vol.Optional(CONF_NAME): cv.template,
vol.Optional(CONF_STATE_CLASS): STATE_CLASSES_SCHEMA,
vol.Optional(CONF_UNIQUE_ID): cv.string,
vol.Optional(CONF_UNIT_OF_MEASUREMENT): cv.string,
}
).extend(TEMPLATE_ENTITY_BASE_SCHEMA.schema)
class _TemplateAttribute:
"""Attribute value linked to template result."""
def __init__(
self,
entity: Entity,
attribute: str,
template: Template,
validator: Callable[[Any], Any] | None = None,
on_update: Callable[[Any], None] | None = None,
none_on_template_error: bool | None = False,
) -> None:
"""Template attribute."""
self._entity = entity
self._attribute = attribute
self.template = template
self.validator = validator
self.on_update = on_update
self.async_update = None
self.none_on_template_error = none_on_template_error
@callback
def async_setup(self) -> None:
"""Config update path for the attribute."""
if self.on_update:
return
if not hasattr(self._entity, self._attribute):
raise AttributeError(f"Attribute '{self._attribute}' does not exist.")
self.on_update = self._default_update
@callback
def _default_update(self, result: str | TemplateError) -> None:
attr_result = None if isinstance(result, TemplateError) else result
setattr(self._entity, self._attribute, attr_result)
@callback
def handle_result(
self,
event: Event | None,
template: Template,
last_result: str | None | TemplateError,
result: str | TemplateError,
) -> None:
"""Handle a template result event callback."""
if isinstance(result, TemplateError):
_LOGGER.error(
"TemplateError('%s') "
"while processing template '%s' "
"for attribute '%s' in entity '%s'",
result,
self.template,
self._attribute,
self._entity.entity_id,
)
if self.none_on_template_error:
self._default_update(result)
else:
assert self.on_update
self.on_update(result)
return
if not self.validator:
assert self.on_update
self.on_update(result)
return
try:
validated = self.validator(result)
except vol.Invalid as ex:
_LOGGER.error(
"Error validating template result '%s' "
"from template '%s' "
"for attribute '%s' in entity %s "
"validation message '%s'",
result,
self.template,
self._attribute,
self._entity.entity_id,
ex.msg,
)
assert self.on_update
self.on_update(None)
return
assert self.on_update
self.on_update(validated)
return
class TemplateEntity(Entity):
"""Entity that uses templates to calculate attributes."""
_attr_available = True
_attr_entity_picture = None
_attr_icon = None
def __init__(
self,
hass: HomeAssistant,
*,
availability_template: Template | None = None,
icon_template: Template | None = None,
entity_picture_template: Template | None = None,
attribute_templates: dict[str, Template] | None = None,
config: ConfigType | None = None,
fallback_name: str | None = None,
unique_id: str | None = None,
) -> None:
"""Template Entity."""
self._template_attrs: dict[Template, list[_TemplateAttribute]] = {}
self._async_update: Callable[[], None] | None = None
self._attr_extra_state_attributes = {}
self._self_ref_update_count = 0
self._attr_unique_id = unique_id
if config is None:
self._attribute_templates = attribute_templates
self._availability_template = availability_template
self._icon_template = icon_template
self._entity_picture_template = entity_picture_template
self._friendly_name_template = None
else:
self._attribute_templates = config.get(CONF_ATTRIBUTES)
self._availability_template = config.get(CONF_AVAILABILITY)
self._icon_template = config.get(CONF_ICON)
self._entity_picture_template = config.get(CONF_PICTURE)
self._friendly_name_template = config.get(CONF_NAME)
class DummyState(State):
"""None-state for template entities not yet added to the state machine."""
def __init__(self) -> None:
"""Initialize a new state."""
super().__init__("unknown.unknown", STATE_UNKNOWN)
self.entity_id = None # type: ignore[assignment]
@property
def name(self) -> str:
"""Name of this state."""
return "<None>"
variables = {"this": DummyState()}
# Try to render the name as it can influence the entity ID
self._attr_name = fallback_name
if self._friendly_name_template:
self._friendly_name_template.hass = hass
with contextlib.suppress(TemplateError):
self._attr_name = self._friendly_name_template.async_render(
variables=variables, parse_result=False
)
# Templates will not render while the entity is unavailable, try to render the
# icon and picture templates.
if self._entity_picture_template:
self._entity_picture_template.hass = hass
with contextlib.suppress(TemplateError):
self._attr_entity_picture = self._entity_picture_template.async_render(
variables=variables, parse_result=False
)
if self._icon_template:
self._icon_template.hass = hass
with contextlib.suppress(TemplateError):
self._attr_icon = self._icon_template.async_render(
variables=variables, parse_result=False
)
@callback
def _update_available(self, result: str | TemplateError) -> None:
if isinstance(result, TemplateError):
self._attr_available = True
return
self._attr_available = result_as_boolean(result)
@callback
def _update_state(self, result: str | TemplateError) -> None:
if self._availability_template:
return
self._attr_available = not isinstance(result, TemplateError)
@callback
def _add_attribute_template(
self, attribute_key: str, attribute_template: Template
) -> None:
"""Create a template tracker for the attribute."""
def _update_attribute(result: str | TemplateError) -> None:
attr_result = None if isinstance(result, TemplateError) else result
self._attr_extra_state_attributes[attribute_key] = attr_result
self.add_template_attribute(
attribute_key, attribute_template, None, _update_attribute
)
def add_template_attribute(
self,
attribute: str,
template: Template,
validator: Callable[[Any], Any] | None = None,
on_update: Callable[[Any], None] | None = None,
none_on_template_error: bool = False,
) -> None:
"""
Call in the constructor to add a template linked to a attribute.
Parameters
----------
attribute
The name of the attribute to link to. This attribute must exist
unless a custom on_update method is supplied.
template
The template to calculate.
validator
Validator function to parse the result and ensure it's valid.
on_update
Called to store the template result rather than storing it
the supplied attribute. Passed the result of the validator, or None
if the template or validator resulted in an error.
"""
assert self.hass is not None, "hass cannot be None"
template.hass = self.hass
template_attribute = _TemplateAttribute(
self, attribute, template, validator, on_update, none_on_template_error
)
self._template_attrs.setdefault(template, [])
self._template_attrs[template].append(template_attribute)
@callback
def _handle_results(
self,
event: Event | None,
updates: list[TrackTemplateResult],
) -> None:
"""Call back the results to the attributes."""
if event:
self.async_set_context(event.context)
entity_id = event and event.data.get(ATTR_ENTITY_ID)
if entity_id and entity_id == self.entity_id:
self._self_ref_update_count += 1
else:
self._self_ref_update_count = 0
if self._self_ref_update_count > len(self._template_attrs):
for update in updates:
_LOGGER.warning(
"Template loop detected while processing event: %s, skipping template render for Template[%s]",
event,
update.template.template,
)
return
for update in updates:
for attr in self._template_attrs[update.template]:
attr.handle_result(
event, update.template, update.last_result, update.result
)
self.async_write_ha_state()
async def _async_template_startup(self, *_: Any) -> None:
template_var_tups: list[TrackTemplate] = []
has_availability_template = False
variables = {"this": TemplateStateFromEntityId(self.hass, self.entity_id)}
for template, attributes in self._template_attrs.items():
template_var_tup = TrackTemplate(template, variables)
is_availability_template = False
for attribute in attributes:
# pylint: disable-next=protected-access
if attribute._attribute == "_attr_available":
has_availability_template = True
is_availability_template = True
attribute.async_setup()
# Insert the availability template first in the list
if is_availability_template:
template_var_tups.insert(0, template_var_tup)
else:
template_var_tups.append(template_var_tup)
result_info = async_track_template_result(
self.hass,
template_var_tups,
self._handle_results,
has_super_template=has_availability_template,
)
self.async_on_remove(result_info.async_remove)
self._async_update = result_info.async_refresh
result_info.async_refresh()
async def async_added_to_hass(self) -> None:
"""Run when entity about to be added to hass."""
if self._availability_template is not None:
self.add_template_attribute(
"_attr_available",
self._availability_template,
None,
self._update_available,
)
if self._attribute_templates is not None:
for key, value in self._attribute_templates.items():
self._add_attribute_template(key, value)
if self._icon_template is not None:
self.add_template_attribute(
"_attr_icon", self._icon_template, vol.Or(cv.whitespace, cv.icon)
)
if self._entity_picture_template is not None:
self.add_template_attribute(
"_attr_entity_picture", self._entity_picture_template
)
if (
self._friendly_name_template is not None
and not self._friendly_name_template.is_static
):
self.add_template_attribute("_attr_name", self._friendly_name_template)
if self.hass.state == CoreState.running:
await self._async_template_startup()
return
self.hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_START, self._async_template_startup
)
async def async_update(self) -> None:
"""Call for forced update."""
assert self._async_update
self._async_update()
async def async_run_script(
self,
script: Script,
*,
run_variables: _VarsType | None = None,
context: Context | None = None,
) -> None:
"""Run an action script."""
if run_variables is None:
run_variables = {}
return await script.async_run(
run_variables={
"this": TemplateStateFromEntityId(self.hass, self.entity_id),
**run_variables,
},
context=context,
)
class TemplateSensor(TemplateEntity, SensorEntity):
"""Representation of a Template Sensor."""
def __init__(
self,
hass: HomeAssistant,
*,
config: dict[str, Any],
fallback_name: str | None,
unique_id: str | None,
) -> None:
"""Initialize the sensor."""
super().__init__(
hass, config=config, fallback_name=fallback_name, unique_id=unique_id
)
self._attr_native_unit_of_measurement = config.get(CONF_UNIT_OF_MEASUREMENT)
self._attr_device_class = config.get(CONF_DEVICE_CLASS)
self._attr_state_class = config.get(CONF_STATE_CLASS)

View File

@ -24,6 +24,8 @@ from homeassistant.const import (
STATE_UNKNOWN,
TEMP_CELSIUS,
)
from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er
from homeassistant.setup import async_setup_component
from tests.common import get_fixture_path
@ -864,3 +866,43 @@ async def test_reload(hass):
assert hass.states.get("sensor.mockreset") is None
assert hass.states.get("sensor.rollout")
@respx.mock
async def test_entity_config(hass: HomeAssistant) -> None:
"""Test entity configuration."""
config = {
DOMAIN: {
# REST configuration
"platform": "rest",
"method": "GET",
"resource": "http://localhost",
# Entity configuration
"icon": "{{'mdi:one_two_three'}}",
"picture": "{{'blabla.png'}}",
"device_class": "temperature",
"name": "{{'REST' + ' ' + 'Sensor'}}",
"state_class": "measurement",
"unique_id": "very_unique",
"unit_of_measurement": "beardsecond",
},
}
respx.get("http://localhost") % HTTPStatus.OK
assert await async_setup_component(hass, DOMAIN, config)
await hass.async_block_till_done()
entity_registry = er.async_get(hass)
assert entity_registry.async_get("sensor.rest_sensor").unique_id == "very_unique"
state = hass.states.get("sensor.rest_sensor")
assert state.state == ""
assert state.attributes == {
"device_class": "temperature",
"entity_picture": "blabla.png",
"friendly_name": "REST Sensor",
"icon": "mdi:one_two_three",
"state_class": "measurement",
"unit_of_measurement": "beardsecond",
}