From 3a744d374b00a93ad0e19d58649c0bd43c8443ed Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Fri, 22 Dec 2023 20:02:55 +0100 Subject: [PATCH] Add support for caching entity properties (#100601) --- homeassistant/components/template/image.py | 4 +- homeassistant/components/weather/__init__.py | 6 +- homeassistant/helpers/entity.py | 241 ++++++++++++++++--- tests/components/zha/test_registries.py | 20 +- tests/helpers/test_entity.py | 120 +++++++++ 5 files changed, 345 insertions(+), 46 deletions(-) diff --git a/homeassistant/components/template/image.py b/homeassistant/components/template/image.py index 55a0e2fb72db..227109d59e20 100644 --- a/homeassistant/components/template/image.py +++ b/homeassistant/components/template/image.py @@ -94,9 +94,9 @@ class StateImageEntity(TemplateEntity, ImageEntity): @property def entity_picture(self) -> str | None: """Return entity picture.""" - # mypy doesn't know about fget: https://github.com/python/mypy/issues/6185 if self._entity_picture_template: - return TemplateEntity.entity_picture.fget(self) # type: ignore[attr-defined] + return TemplateEntity.entity_picture.__get__(self) + # mypy doesn't know about fget: https://github.com/python/mypy/issues/6185 return ImageEntity.entity_picture.fget(self) # type: ignore[attr-defined] @callback diff --git a/homeassistant/components/weather/__init__.py b/homeassistant/components/weather/__init__.py index 899181f2b5f6..993c5e9503bb 100644 --- a/homeassistant/components/weather/__init__.py +++ b/homeassistant/components/weather/__init__.py @@ -44,7 +44,7 @@ from homeassistant.helpers.config_validation import ( # noqa: F401 PLATFORM_SCHEMA, PLATFORM_SCHEMA_BASE, ) -from homeassistant.helpers.entity import Entity, EntityDescription +from homeassistant.helpers.entity import ABCCachedProperties, Entity, EntityDescription from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.entity_platform import EntityPlatform import homeassistant.helpers.issue_registry as ir @@ -254,10 +254,10 @@ class WeatherEntityDescription(EntityDescription, frozen_or_thawed=True): """A class that describes weather entities.""" -class PostInitMeta(abc.ABCMeta): +class PostInitMeta(ABCCachedProperties): """Meta class which calls __post_init__ after __new__ and __init__.""" - def __call__(cls, *args: Any, **kwargs: Any) -> Any: + def __call__(cls, *args: Any, **kwargs: Any) -> Any: # noqa: N805 ruff bug, ruff does not understand this is a metaclass """Create an instance.""" instance: PostInit = super().__call__(*args, **kwargs) instance.__post_init__(*args, **kwargs) diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index 9f5ff3dad52a..4fa2d5c3e515 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -1,10 +1,10 @@ """An abstract class for entities.""" from __future__ import annotations -from abc import ABC +from abc import ABCMeta import asyncio from collections import deque -from collections.abc import Coroutine, Iterable, Mapping, MutableMapping +from collections.abc import Callable, Coroutine, Iterable, Mapping, MutableMapping import dataclasses from datetime import timedelta from enum import Enum, auto @@ -26,7 +26,6 @@ from typing import ( import voluptuous as vol -from homeassistant.backports.functools import cached_property from homeassistant.config import DATA_CUSTOMIZE from homeassistant.const import ( ATTR_ASSUMED_STATE, @@ -63,8 +62,11 @@ from .event import ( from .typing import UNDEFINED, EventType, StateType, UndefinedType if TYPE_CHECKING: - from .entity_platform import EntityPlatform + from functools import cached_property + from .entity_platform import EntityPlatform +else: + from homeassistant.backports.functools import cached_property _T = TypeVar("_T") @@ -259,7 +261,177 @@ class CalculatedState: shadowed_attributes: Mapping[str, Any] -class Entity(ABC): +class CachedProperties(type): + """Metaclass which invalidates cached entity properties on write to _attr_. + + A class which has CachedProperties can optionally have a list of cached + properties, passed as cached_properties, which must be a set of strings. + - Each item in the cached_property set must be the name of a method decorated + with @cached_property + - For each item in the cached_property set, a property function with the + same name, prefixed with _attr_, will be created + - The property _attr_-property functions allow setting, getting and deleting + data, which will be stored in an attribute prefixed with __attr_ + - The _attr_-property setter will invalidate the @cached_property by calling + delattr on it + """ + + def __new__( + mcs, # noqa: N804 ruff bug, ruff does not understand this is a metaclass + name: str, + bases: tuple[type, ...], + namespace: dict[Any, Any], + cached_properties: set[str] | None = None, + **kwargs: Any, + ) -> Any: + """Start creating a new CachedProperties. + + Pop cached_properties and store it in the namespace. + """ + namespace["_CachedProperties__cached_properties"] = cached_properties or set() + return super().__new__(mcs, name, bases, namespace) + + def __init__( + cls, + name: str, + bases: tuple[type, ...], + namespace: dict[Any, Any], + **kwargs: Any, + ) -> None: + """Finish creating a new CachedProperties. + + Wrap _attr_ for cached properties in property objects. + """ + + def deleter(name: str) -> Callable[[Any], None]: + """Create a deleter for an _attr_ property.""" + private_attr_name = f"__attr_{name}" + + def _deleter(o: Any) -> None: + """Delete an _attr_ property. + + Does two things: + - Delete the __attr_ attribute + - Invalidate the cache of the cached property + + Raises AttributeError if the __attr_ attribute does not exist + """ + # Invalidate the cache of the cached property + try: # noqa: SIM105 suppress is much slower + delattr(o, name) + except AttributeError: + pass + # Delete the __attr_ attribute + delattr(o, private_attr_name) + + return _deleter + + def getter(name: str) -> Callable[[Any], Any]: + """Create a getter for an _attr_ property.""" + private_attr_name = f"__attr_{name}" + + def _getter(o: Any) -> Any: + """Get an _attr_ property from the backing __attr attribute.""" + return getattr(o, private_attr_name) + + return _getter + + def setter(name: str) -> Callable[[Any, Any], None]: + """Create a setter for an _attr_ property.""" + private_attr_name = f"__attr_{name}" + + def _setter(o: Any, val: Any) -> None: + """Set an _attr_ property to the backing __attr attribute. + + Also invalidates the corresponding cached_property by calling + delattr on it. + """ + setattr(o, private_attr_name, val) + try: # noqa: SIM105 suppress is much slower + delattr(o, name) + except AttributeError: + pass + + return _setter + + def make_property(name: str) -> property: + """Help create a property object.""" + return property(fget=getter(name), fset=setter(name), fdel=deleter(name)) + + def wrap_attr(cls: CachedProperties, property_name: str) -> None: + """Wrap a cached property's corresponding _attr in a property. + + If the class being created has an _attr class attribute, move it, and its + annotations, to the __attr attribute. + """ + attr_name = f"_attr_{property_name}" + private_attr_name = f"__attr_{property_name}" + # Check if an _attr_ class attribute exits and move it to __attr_. We check + # __dict__ here because we don't care about _attr_ class attributes in parents. + if attr_name in cls.__dict__: + setattr(cls, private_attr_name, getattr(cls, attr_name)) + annotations = cls.__annotations__ + if attr_name in annotations: + annotations[private_attr_name] = annotations.pop(attr_name) + # Create the _attr_ property + setattr(cls, attr_name, make_property(property_name)) + + cached_properties: set[str] = namespace["_CachedProperties__cached_properties"] + seen_props: set[str] = set() # Keep track of properties which have been handled + for property_name in cached_properties: + wrap_attr(cls, property_name) + seen_props.add(property_name) + + # Look for cached properties of parent classes where this class has + # corresponding _attr_ class attributes and re-wrap them. + for parent in cls.__mro__[:0:-1]: + if "_CachedProperties__cached_properties" not in parent.__dict__: + continue + cached_properties = getattr(parent, "_CachedProperties__cached_properties") + for property_name in cached_properties: + if property_name in seen_props: + continue + attr_name = f"_attr_{property_name}" + # Check if an _attr_ class attribute exits. We check __dict__ here because + # we don't care about _attr_ class attributes in parents. + if (attr_name) not in cls.__dict__: + continue + wrap_attr(cls, property_name) + seen_props.add(property_name) + + +class ABCCachedProperties(CachedProperties, ABCMeta): + """Add ABCMeta to CachedProperties.""" + + +CACHED_PROPERTIES_WITH_ATTR_ = { + "assumed_state", + "attribution", + "available", + "capability_attributes", + "device_class", + "device_info", + "entity_category", + "has_entity_name", + "entity_picture", + "entity_registry_enabled_default", + "entity_registry_visible_default", + "extra_state_attributes", + "force_update", + "icon", + "name", + "should_poll", + "state", + "supported_features", + "translation_key", + "unique_id", + "unit_of_measurement", +} + + +class Entity( + metaclass=ABCCachedProperties, cached_properties=CACHED_PROPERTIES_WITH_ATTR_ +): """An abstract class for Home Assistant entities.""" # SAFE TO OVERWRITE @@ -367,7 +539,7 @@ class Entity(ABC): cls._entity_component_unrecorded_attributes | cls._unrecorded_attributes ) - @property + @cached_property def should_poll(self) -> bool: """Return True if entity has to be polled for state. @@ -375,7 +547,7 @@ class Entity(ABC): """ return self._attr_should_poll - @property + @cached_property def unique_id(self) -> str | None: """Return a unique ID.""" return self._attr_unique_id @@ -398,7 +570,7 @@ class Entity(ABC): return not self.name - @property + @cached_property def has_entity_name(self) -> bool: """Return if the name of the entity is describing only the entity itself.""" if hasattr(self, "_attr_has_entity_name"): @@ -479,10 +651,17 @@ class Entity(ABC): @property def suggested_object_id(self) -> str | None: """Return input for object id.""" - # The check for self.platform guards against integrations not using an - # EntityComponent and can be removed in HA Core 2024.1 - # mypy doesn't know about fget: https://github.com/python/mypy/issues/6185 - if self.__class__.name.fget is Entity.name.fget and self.platform: # type: ignore[attr-defined] + if ( + # Check our class has overridden the name property from Entity + # We need to use type.__getattribute__ to retrieve the underlying + # property or cached_property object instead of the property's + # value. + type.__getattribute__(self.__class__, "name") + is type.__getattribute__(Entity, "name") + # The check for self.platform guards against integrations not using an + # EntityComponent and can be removed in HA Core 2024.1 + and self.platform + ): name = self._name_internal( self._object_id_device_class_name, self.platform.object_id_platform_translations, @@ -491,7 +670,7 @@ class Entity(ABC): name = self.name return None if name is UNDEFINED else name - @property + @cached_property def name(self) -> str | UndefinedType | None: """Return the name of the entity.""" # The check for self.platform guards against integrations not using an @@ -503,12 +682,12 @@ class Entity(ABC): self.platform.platform_translations, ) - @property + @cached_property def state(self) -> StateType: """Return the state of the entity.""" return self._attr_state - @property + @cached_property def capability_attributes(self) -> Mapping[str, Any] | None: """Return the capability attributes. @@ -531,7 +710,7 @@ class Entity(ABC): """ return None - @property + @cached_property def state_attributes(self) -> dict[str, Any] | None: """Return the state attributes. @@ -540,7 +719,7 @@ class Entity(ABC): """ return None - @property + @cached_property def extra_state_attributes(self) -> Mapping[str, Any] | None: """Return entity specific state attributes. @@ -551,7 +730,7 @@ class Entity(ABC): return self._attr_extra_state_attributes return None - @property + @cached_property def device_info(self) -> DeviceInfo | None: """Return device specific attributes. @@ -559,7 +738,7 @@ class Entity(ABC): """ return self._attr_device_info - @property + @cached_property def device_class(self) -> str | None: """Return the class of this device, from component DEVICE_CLASSES.""" if hasattr(self, "_attr_device_class"): @@ -568,7 +747,7 @@ class Entity(ABC): return self.entity_description.device_class return None - @property + @cached_property def unit_of_measurement(self) -> str | None: """Return the unit of measurement of this entity, if any.""" if hasattr(self, "_attr_unit_of_measurement"): @@ -577,7 +756,7 @@ class Entity(ABC): return self.entity_description.unit_of_measurement return None - @property + @cached_property def icon(self) -> str | None: """Return the icon to use in the frontend, if any.""" if hasattr(self, "_attr_icon"): @@ -586,22 +765,22 @@ class Entity(ABC): return self.entity_description.icon return None - @property + @cached_property def entity_picture(self) -> str | None: """Return the entity picture to use in the frontend, if any.""" return self._attr_entity_picture - @property + @cached_property def available(self) -> bool: """Return True if entity is available.""" return self._attr_available - @property + @cached_property def assumed_state(self) -> bool: """Return True if unable to access real state of the entity.""" return self._attr_assumed_state - @property + @cached_property def force_update(self) -> bool: """Return True if state updates should be forced. @@ -614,12 +793,12 @@ class Entity(ABC): return self.entity_description.force_update return False - @property + @cached_property def supported_features(self) -> int | None: """Flag supported features.""" return self._attr_supported_features - @property + @cached_property def entity_registry_enabled_default(self) -> bool: """Return if the entity should be enabled when first added. @@ -631,7 +810,7 @@ class Entity(ABC): return self.entity_description.entity_registry_enabled_default return True - @property + @cached_property def entity_registry_visible_default(self) -> bool: """Return if the entity should be visible when first added. @@ -643,12 +822,12 @@ class Entity(ABC): return self.entity_description.entity_registry_visible_default return True - @property + @cached_property def attribution(self) -> str | None: """Return the attribution.""" return self._attr_attribution - @property + @cached_property def entity_category(self) -> EntityCategory | None: """Return the category of the entity, if any.""" if hasattr(self, "_attr_entity_category"): @@ -657,7 +836,7 @@ class Entity(ABC): return self.entity_description.entity_category return None - @property + @cached_property def translation_key(self) -> str | None: """Return the translation key to translate the entity's states.""" if hasattr(self, "_attr_translation_key"): diff --git a/tests/components/zha/test_registries.py b/tests/components/zha/test_registries.py index 68ff116adead..80845cf9866f 100644 --- a/tests/components/zha/test_registries.py +++ b/tests/components/zha/test_registries.py @@ -585,18 +585,18 @@ def test_quirk_classes() -> None: def test_entity_names() -> None: """Make sure that all handlers expose entities with valid names.""" - for _, entities in iter_all_rules(): - for entity in entities: - if hasattr(entity, "_attr_name"): + for _, entity_classes in iter_all_rules(): + for entity_class in entity_classes: + if hasattr(entity_class, "__attr_name"): # The entity has a name - assert isinstance(entity._attr_name, str) and entity._attr_name - elif hasattr(entity, "_attr_translation_key"): + assert (name := entity_class.__attr_name) and isinstance(name, str) + elif hasattr(entity_class, "__attr_translation_key"): assert ( - isinstance(entity._attr_translation_key, str) - and entity._attr_translation_key + isinstance(entity_class.__attr_translation_key, str) + and entity_class.__attr_translation_key ) - elif hasattr(entity, "_attr_device_class"): - assert entity._attr_device_class + elif hasattr(entity_class, "__attr_device_class"): + assert entity_class.__attr_device_class else: # The only exception (for now) is IASZone - assert entity is IASZone + assert entity_class is IASZone diff --git a/tests/helpers/test_entity.py b/tests/helpers/test_entity.py index c3021e397ee2..2bf90660f310 100644 --- a/tests/helpers/test_entity.py +++ b/tests/helpers/test_entity.py @@ -13,6 +13,7 @@ import pytest from syrupy.assertion import SnapshotAssertion import voluptuous as vol +from homeassistant.backports.functools import cached_property from homeassistant.const import ( ATTR_ATTRIBUTION, ATTR_DEVICE_CLASS, @@ -1905,3 +1906,122 @@ async def test_update_capabilities_too_often_cooldown( assert entry.supported_features == supported_features + 1 assert capabilities_too_often_warning not in caplog.text + + +@pytest.mark.parametrize( + ("property", "default_value", "values"), [("attribution", None, ["abcd", "efgh"])] +) +async def test_cached_entity_properties( + hass: HomeAssistant, property: str, default_value: Any, values: Any +) -> None: + """Test entity properties are cached.""" + ent1 = entity.Entity() + ent2 = entity.Entity() + assert getattr(ent1, property) == default_value + assert getattr(ent2, property) == default_value + + # Test set + setattr(ent1, f"_attr_{property}", values[0]) + assert getattr(ent1, property) == values[0] + assert getattr(ent2, property) == default_value + + # Test update + setattr(ent1, f"_attr_{property}", values[1]) + assert getattr(ent1, property) == values[1] + assert getattr(ent2, property) == default_value + + # Test delete + delattr(ent1, f"_attr_{property}") + assert getattr(ent1, property) == default_value + assert getattr(ent2, property) == default_value + + +async def test_cached_entity_property_delete_attr(hass: HomeAssistant) -> None: + """Test deleting an _attr corresponding to a cached property.""" + property = "has_entity_name" + + ent = entity.Entity() + assert not hasattr(ent, f"_attr_{property}") + with pytest.raises(AttributeError): + delattr(ent, f"_attr_{property}") + assert getattr(ent, property) is False + + with pytest.raises(AttributeError): + delattr(ent, f"_attr_{property}") + assert not hasattr(ent, f"_attr_{property}") + assert getattr(ent, property) is False + + setattr(ent, f"_attr_{property}", True) + assert getattr(ent, property) is True + + delattr(ent, f"_attr_{property}") + assert not hasattr(ent, f"_attr_{property}") + assert getattr(ent, property) is False + + +async def test_cached_entity_property_class_attribute(hass: HomeAssistant) -> None: + """Test entity properties on class level work in derived classes.""" + property = "attribution" + values = ["abcd", "efgh"] + + class EntityWithClassAttribute1(entity.Entity): + """A derived class which overrides an _attr_ from a parent.""" + + _attr_attribution = values[0] + + class EntityWithClassAttribute2(entity.Entity, cached_properties={property}): + """A derived class which overrides an _attr_ from a parent. + + This class also redundantly marks the overridden _attr_ as cached. + """ + + _attr_attribution = values[0] + + class EntityWithClassAttribute3(entity.Entity, cached_properties={property}): + """A derived class which overrides an _attr_ from a parent. + + This class overrides the attribute property. + """ + + def __init__(self): + self._attr_attribution = values[0] + + @cached_property + def attribution(self) -> str | None: + """Return the attribution.""" + return self._attr_attribution + + class EntityWithClassAttribute4(entity.Entity, cached_properties={property}): + """A derived class which overrides an _attr_ from a parent. + + This class overrides the attribute property and the _attr_. + """ + + _attr_attribution = values[0] + + @cached_property + def attribution(self) -> str | None: + """Return the attribution.""" + return self._attr_attribution + + classes = ( + EntityWithClassAttribute1, + EntityWithClassAttribute2, + EntityWithClassAttribute3, + EntityWithClassAttribute4, + ) + + entities: list[tuple[entity.Entity, entity.Entity]] = [] + for cls in classes: + entities.append((cls(), cls())) + + for ent in entities: + assert getattr(ent[0], property) == values[0] + assert getattr(ent[1], property) == values[0] + + # Test update + for ent in entities: + setattr(ent[0], f"_attr_{property}", values[1]) + for ent in entities: + assert getattr(ent[0], property) == values[1] + assert getattr(ent[1], property) == values[0]