1
mirror of https://github.com/home-assistant/core synced 2024-07-15 09:42:11 +02:00

Add support for caching entity properties (#100601)

This commit is contained in:
Erik Montnemery 2023-12-22 20:02:55 +01:00 committed by GitHub
parent 087eb86e37
commit 3a744d374b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 345 additions and 46 deletions

View File

@ -94,9 +94,9 @@ class StateImageEntity(TemplateEntity, ImageEntity):
@property
def entity_picture(self) -> str | None:
"""Return entity picture."""
# mypy doesn't know about fget: https://github.com/python/mypy/issues/6185
if self._entity_picture_template:
return TemplateEntity.entity_picture.fget(self) # type: ignore[attr-defined]
return TemplateEntity.entity_picture.__get__(self)
# mypy doesn't know about fget: https://github.com/python/mypy/issues/6185
return ImageEntity.entity_picture.fget(self) # type: ignore[attr-defined]
@callback

View File

@ -44,7 +44,7 @@ from homeassistant.helpers.config_validation import ( # noqa: F401
PLATFORM_SCHEMA,
PLATFORM_SCHEMA_BASE,
)
from homeassistant.helpers.entity import Entity, EntityDescription
from homeassistant.helpers.entity import ABCCachedProperties, Entity, EntityDescription
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.entity_platform import EntityPlatform
import homeassistant.helpers.issue_registry as ir
@ -254,10 +254,10 @@ class WeatherEntityDescription(EntityDescription, frozen_or_thawed=True):
"""A class that describes weather entities."""
class PostInitMeta(abc.ABCMeta):
class PostInitMeta(ABCCachedProperties):
"""Meta class which calls __post_init__ after __new__ and __init__."""
def __call__(cls, *args: Any, **kwargs: Any) -> Any:
def __call__(cls, *args: Any, **kwargs: Any) -> Any: # noqa: N805 ruff bug, ruff does not understand this is a metaclass
"""Create an instance."""
instance: PostInit = super().__call__(*args, **kwargs)
instance.__post_init__(*args, **kwargs)

View File

@ -1,10 +1,10 @@
"""An abstract class for entities."""
from __future__ import annotations
from abc import ABC
from abc import ABCMeta
import asyncio
from collections import deque
from collections.abc import Coroutine, Iterable, Mapping, MutableMapping
from collections.abc import Callable, Coroutine, Iterable, Mapping, MutableMapping
import dataclasses
from datetime import timedelta
from enum import Enum, auto
@ -26,7 +26,6 @@ from typing import (
import voluptuous as vol
from homeassistant.backports.functools import cached_property
from homeassistant.config import DATA_CUSTOMIZE
from homeassistant.const import (
ATTR_ASSUMED_STATE,
@ -63,8 +62,11 @@ from .event import (
from .typing import UNDEFINED, EventType, StateType, UndefinedType
if TYPE_CHECKING:
from .entity_platform import EntityPlatform
from functools import cached_property
from .entity_platform import EntityPlatform
else:
from homeassistant.backports.functools import cached_property
_T = TypeVar("_T")
@ -259,7 +261,177 @@ class CalculatedState:
shadowed_attributes: Mapping[str, Any]
class Entity(ABC):
class CachedProperties(type):
"""Metaclass which invalidates cached entity properties on write to _attr_.
A class which has CachedProperties can optionally have a list of cached
properties, passed as cached_properties, which must be a set of strings.
- Each item in the cached_property set must be the name of a method decorated
with @cached_property
- For each item in the cached_property set, a property function with the
same name, prefixed with _attr_, will be created
- The property _attr_-property functions allow setting, getting and deleting
data, which will be stored in an attribute prefixed with __attr_
- The _attr_-property setter will invalidate the @cached_property by calling
delattr on it
"""
def __new__(
mcs, # noqa: N804 ruff bug, ruff does not understand this is a metaclass
name: str,
bases: tuple[type, ...],
namespace: dict[Any, Any],
cached_properties: set[str] | None = None,
**kwargs: Any,
) -> Any:
"""Start creating a new CachedProperties.
Pop cached_properties and store it in the namespace.
"""
namespace["_CachedProperties__cached_properties"] = cached_properties or set()
return super().__new__(mcs, name, bases, namespace)
def __init__(
cls,
name: str,
bases: tuple[type, ...],
namespace: dict[Any, Any],
**kwargs: Any,
) -> None:
"""Finish creating a new CachedProperties.
Wrap _attr_ for cached properties in property objects.
"""
def deleter(name: str) -> Callable[[Any], None]:
"""Create a deleter for an _attr_ property."""
private_attr_name = f"__attr_{name}"
def _deleter(o: Any) -> None:
"""Delete an _attr_ property.
Does two things:
- Delete the __attr_ attribute
- Invalidate the cache of the cached property
Raises AttributeError if the __attr_ attribute does not exist
"""
# Invalidate the cache of the cached property
try: # noqa: SIM105 suppress is much slower
delattr(o, name)
except AttributeError:
pass
# Delete the __attr_ attribute
delattr(o, private_attr_name)
return _deleter
def getter(name: str) -> Callable[[Any], Any]:
"""Create a getter for an _attr_ property."""
private_attr_name = f"__attr_{name}"
def _getter(o: Any) -> Any:
"""Get an _attr_ property from the backing __attr attribute."""
return getattr(o, private_attr_name)
return _getter
def setter(name: str) -> Callable[[Any, Any], None]:
"""Create a setter for an _attr_ property."""
private_attr_name = f"__attr_{name}"
def _setter(o: Any, val: Any) -> None:
"""Set an _attr_ property to the backing __attr attribute.
Also invalidates the corresponding cached_property by calling
delattr on it.
"""
setattr(o, private_attr_name, val)
try: # noqa: SIM105 suppress is much slower
delattr(o, name)
except AttributeError:
pass
return _setter
def make_property(name: str) -> property:
"""Help create a property object."""
return property(fget=getter(name), fset=setter(name), fdel=deleter(name))
def wrap_attr(cls: CachedProperties, property_name: str) -> None:
"""Wrap a cached property's corresponding _attr in a property.
If the class being created has an _attr class attribute, move it, and its
annotations, to the __attr attribute.
"""
attr_name = f"_attr_{property_name}"
private_attr_name = f"__attr_{property_name}"
# Check if an _attr_ class attribute exits and move it to __attr_. We check
# __dict__ here because we don't care about _attr_ class attributes in parents.
if attr_name in cls.__dict__:
setattr(cls, private_attr_name, getattr(cls, attr_name))
annotations = cls.__annotations__
if attr_name in annotations:
annotations[private_attr_name] = annotations.pop(attr_name)
# Create the _attr_ property
setattr(cls, attr_name, make_property(property_name))
cached_properties: set[str] = namespace["_CachedProperties__cached_properties"]
seen_props: set[str] = set() # Keep track of properties which have been handled
for property_name in cached_properties:
wrap_attr(cls, property_name)
seen_props.add(property_name)
# Look for cached properties of parent classes where this class has
# corresponding _attr_ class attributes and re-wrap them.
for parent in cls.__mro__[:0:-1]:
if "_CachedProperties__cached_properties" not in parent.__dict__:
continue
cached_properties = getattr(parent, "_CachedProperties__cached_properties")
for property_name in cached_properties:
if property_name in seen_props:
continue
attr_name = f"_attr_{property_name}"
# Check if an _attr_ class attribute exits. We check __dict__ here because
# we don't care about _attr_ class attributes in parents.
if (attr_name) not in cls.__dict__:
continue
wrap_attr(cls, property_name)
seen_props.add(property_name)
class ABCCachedProperties(CachedProperties, ABCMeta):
"""Add ABCMeta to CachedProperties."""
CACHED_PROPERTIES_WITH_ATTR_ = {
"assumed_state",
"attribution",
"available",
"capability_attributes",
"device_class",
"device_info",
"entity_category",
"has_entity_name",
"entity_picture",
"entity_registry_enabled_default",
"entity_registry_visible_default",
"extra_state_attributes",
"force_update",
"icon",
"name",
"should_poll",
"state",
"supported_features",
"translation_key",
"unique_id",
"unit_of_measurement",
}
class Entity(
metaclass=ABCCachedProperties, cached_properties=CACHED_PROPERTIES_WITH_ATTR_
):
"""An abstract class for Home Assistant entities."""
# SAFE TO OVERWRITE
@ -367,7 +539,7 @@ class Entity(ABC):
cls._entity_component_unrecorded_attributes | cls._unrecorded_attributes
)
@property
@cached_property
def should_poll(self) -> bool:
"""Return True if entity has to be polled for state.
@ -375,7 +547,7 @@ class Entity(ABC):
"""
return self._attr_should_poll
@property
@cached_property
def unique_id(self) -> str | None:
"""Return a unique ID."""
return self._attr_unique_id
@ -398,7 +570,7 @@ class Entity(ABC):
return not self.name
@property
@cached_property
def has_entity_name(self) -> bool:
"""Return if the name of the entity is describing only the entity itself."""
if hasattr(self, "_attr_has_entity_name"):
@ -479,10 +651,17 @@ class Entity(ABC):
@property
def suggested_object_id(self) -> str | None:
"""Return input for object id."""
# The check for self.platform guards against integrations not using an
# EntityComponent and can be removed in HA Core 2024.1
# mypy doesn't know about fget: https://github.com/python/mypy/issues/6185
if self.__class__.name.fget is Entity.name.fget and self.platform: # type: ignore[attr-defined]
if (
# Check our class has overridden the name property from Entity
# We need to use type.__getattribute__ to retrieve the underlying
# property or cached_property object instead of the property's
# value.
type.__getattribute__(self.__class__, "name")
is type.__getattribute__(Entity, "name")
# The check for self.platform guards against integrations not using an
# EntityComponent and can be removed in HA Core 2024.1
and self.platform
):
name = self._name_internal(
self._object_id_device_class_name,
self.platform.object_id_platform_translations,
@ -491,7 +670,7 @@ class Entity(ABC):
name = self.name
return None if name is UNDEFINED else name
@property
@cached_property
def name(self) -> str | UndefinedType | None:
"""Return the name of the entity."""
# The check for self.platform guards against integrations not using an
@ -503,12 +682,12 @@ class Entity(ABC):
self.platform.platform_translations,
)
@property
@cached_property
def state(self) -> StateType:
"""Return the state of the entity."""
return self._attr_state
@property
@cached_property
def capability_attributes(self) -> Mapping[str, Any] | None:
"""Return the capability attributes.
@ -531,7 +710,7 @@ class Entity(ABC):
"""
return None
@property
@cached_property
def state_attributes(self) -> dict[str, Any] | None:
"""Return the state attributes.
@ -540,7 +719,7 @@ class Entity(ABC):
"""
return None
@property
@cached_property
def extra_state_attributes(self) -> Mapping[str, Any] | None:
"""Return entity specific state attributes.
@ -551,7 +730,7 @@ class Entity(ABC):
return self._attr_extra_state_attributes
return None
@property
@cached_property
def device_info(self) -> DeviceInfo | None:
"""Return device specific attributes.
@ -559,7 +738,7 @@ class Entity(ABC):
"""
return self._attr_device_info
@property
@cached_property
def device_class(self) -> str | None:
"""Return the class of this device, from component DEVICE_CLASSES."""
if hasattr(self, "_attr_device_class"):
@ -568,7 +747,7 @@ class Entity(ABC):
return self.entity_description.device_class
return None
@property
@cached_property
def unit_of_measurement(self) -> str | None:
"""Return the unit of measurement of this entity, if any."""
if hasattr(self, "_attr_unit_of_measurement"):
@ -577,7 +756,7 @@ class Entity(ABC):
return self.entity_description.unit_of_measurement
return None
@property
@cached_property
def icon(self) -> str | None:
"""Return the icon to use in the frontend, if any."""
if hasattr(self, "_attr_icon"):
@ -586,22 +765,22 @@ class Entity(ABC):
return self.entity_description.icon
return None
@property
@cached_property
def entity_picture(self) -> str | None:
"""Return the entity picture to use in the frontend, if any."""
return self._attr_entity_picture
@property
@cached_property
def available(self) -> bool:
"""Return True if entity is available."""
return self._attr_available
@property
@cached_property
def assumed_state(self) -> bool:
"""Return True if unable to access real state of the entity."""
return self._attr_assumed_state
@property
@cached_property
def force_update(self) -> bool:
"""Return True if state updates should be forced.
@ -614,12 +793,12 @@ class Entity(ABC):
return self.entity_description.force_update
return False
@property
@cached_property
def supported_features(self) -> int | None:
"""Flag supported features."""
return self._attr_supported_features
@property
@cached_property
def entity_registry_enabled_default(self) -> bool:
"""Return if the entity should be enabled when first added.
@ -631,7 +810,7 @@ class Entity(ABC):
return self.entity_description.entity_registry_enabled_default
return True
@property
@cached_property
def entity_registry_visible_default(self) -> bool:
"""Return if the entity should be visible when first added.
@ -643,12 +822,12 @@ class Entity(ABC):
return self.entity_description.entity_registry_visible_default
return True
@property
@cached_property
def attribution(self) -> str | None:
"""Return the attribution."""
return self._attr_attribution
@property
@cached_property
def entity_category(self) -> EntityCategory | None:
"""Return the category of the entity, if any."""
if hasattr(self, "_attr_entity_category"):
@ -657,7 +836,7 @@ class Entity(ABC):
return self.entity_description.entity_category
return None
@property
@cached_property
def translation_key(self) -> str | None:
"""Return the translation key to translate the entity's states."""
if hasattr(self, "_attr_translation_key"):

View File

@ -585,18 +585,18 @@ def test_quirk_classes() -> None:
def test_entity_names() -> None:
"""Make sure that all handlers expose entities with valid names."""
for _, entities in iter_all_rules():
for entity in entities:
if hasattr(entity, "_attr_name"):
for _, entity_classes in iter_all_rules():
for entity_class in entity_classes:
if hasattr(entity_class, "__attr_name"):
# The entity has a name
assert isinstance(entity._attr_name, str) and entity._attr_name
elif hasattr(entity, "_attr_translation_key"):
assert (name := entity_class.__attr_name) and isinstance(name, str)
elif hasattr(entity_class, "__attr_translation_key"):
assert (
isinstance(entity._attr_translation_key, str)
and entity._attr_translation_key
isinstance(entity_class.__attr_translation_key, str)
and entity_class.__attr_translation_key
)
elif hasattr(entity, "_attr_device_class"):
assert entity._attr_device_class
elif hasattr(entity_class, "__attr_device_class"):
assert entity_class.__attr_device_class
else:
# The only exception (for now) is IASZone
assert entity is IASZone
assert entity_class is IASZone

View File

@ -13,6 +13,7 @@ import pytest
from syrupy.assertion import SnapshotAssertion
import voluptuous as vol
from homeassistant.backports.functools import cached_property
from homeassistant.const import (
ATTR_ATTRIBUTION,
ATTR_DEVICE_CLASS,
@ -1905,3 +1906,122 @@ async def test_update_capabilities_too_often_cooldown(
assert entry.supported_features == supported_features + 1
assert capabilities_too_often_warning not in caplog.text
@pytest.mark.parametrize(
("property", "default_value", "values"), [("attribution", None, ["abcd", "efgh"])]
)
async def test_cached_entity_properties(
hass: HomeAssistant, property: str, default_value: Any, values: Any
) -> None:
"""Test entity properties are cached."""
ent1 = entity.Entity()
ent2 = entity.Entity()
assert getattr(ent1, property) == default_value
assert getattr(ent2, property) == default_value
# Test set
setattr(ent1, f"_attr_{property}", values[0])
assert getattr(ent1, property) == values[0]
assert getattr(ent2, property) == default_value
# Test update
setattr(ent1, f"_attr_{property}", values[1])
assert getattr(ent1, property) == values[1]
assert getattr(ent2, property) == default_value
# Test delete
delattr(ent1, f"_attr_{property}")
assert getattr(ent1, property) == default_value
assert getattr(ent2, property) == default_value
async def test_cached_entity_property_delete_attr(hass: HomeAssistant) -> None:
"""Test deleting an _attr corresponding to a cached property."""
property = "has_entity_name"
ent = entity.Entity()
assert not hasattr(ent, f"_attr_{property}")
with pytest.raises(AttributeError):
delattr(ent, f"_attr_{property}")
assert getattr(ent, property) is False
with pytest.raises(AttributeError):
delattr(ent, f"_attr_{property}")
assert not hasattr(ent, f"_attr_{property}")
assert getattr(ent, property) is False
setattr(ent, f"_attr_{property}", True)
assert getattr(ent, property) is True
delattr(ent, f"_attr_{property}")
assert not hasattr(ent, f"_attr_{property}")
assert getattr(ent, property) is False
async def test_cached_entity_property_class_attribute(hass: HomeAssistant) -> None:
"""Test entity properties on class level work in derived classes."""
property = "attribution"
values = ["abcd", "efgh"]
class EntityWithClassAttribute1(entity.Entity):
"""A derived class which overrides an _attr_ from a parent."""
_attr_attribution = values[0]
class EntityWithClassAttribute2(entity.Entity, cached_properties={property}):
"""A derived class which overrides an _attr_ from a parent.
This class also redundantly marks the overridden _attr_ as cached.
"""
_attr_attribution = values[0]
class EntityWithClassAttribute3(entity.Entity, cached_properties={property}):
"""A derived class which overrides an _attr_ from a parent.
This class overrides the attribute property.
"""
def __init__(self):
self._attr_attribution = values[0]
@cached_property
def attribution(self) -> str | None:
"""Return the attribution."""
return self._attr_attribution
class EntityWithClassAttribute4(entity.Entity, cached_properties={property}):
"""A derived class which overrides an _attr_ from a parent.
This class overrides the attribute property and the _attr_.
"""
_attr_attribution = values[0]
@cached_property
def attribution(self) -> str | None:
"""Return the attribution."""
return self._attr_attribution
classes = (
EntityWithClassAttribute1,
EntityWithClassAttribute2,
EntityWithClassAttribute3,
EntityWithClassAttribute4,
)
entities: list[tuple[entity.Entity, entity.Entity]] = []
for cls in classes:
entities.append((cls(), cls()))
for ent in entities:
assert getattr(ent[0], property) == values[0]
assert getattr(ent[1], property) == values[0]
# Test update
for ent in entities:
setattr(ent[0], f"_attr_{property}", values[1])
for ent in entities:
assert getattr(ent[0], property) == values[1]
assert getattr(ent[1], property) == values[0]