Support restoring SensorEntity native_value (#66068)

This commit is contained in:
Erik Montnemery 2022-02-08 23:00:26 +01:00 committed by GitHub
parent f8a84f0101
commit 009b31941a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 330 additions and 31 deletions

View File

@ -52,7 +52,9 @@ from homeassistant.helpers.config_validation import ( # noqa: F401
)
from homeassistant.helpers.entity import Entity, EntityDescription
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.restore_state import ExtraStoredData, RestoreEntity
from homeassistant.helpers.typing import ConfigType, StateType
from homeassistant.util import dt as dt_util
from .const import CONF_STATE_CLASS # noqa: F401
@ -447,3 +449,62 @@ class SensorEntity(Entity):
return f"<Entity {self.name}>"
return super().__repr__()
@dataclass
class SensorExtraStoredData(ExtraStoredData):
"""Object to hold extra stored data."""
native_value: StateType | date | datetime
native_unit_of_measurement: str | None
def as_dict(self) -> dict[str, Any]:
"""Return a dict representation of the sensor data."""
native_value: StateType | date | datetime | dict[str, str] = self.native_value
if isinstance(native_value, (date, datetime)):
native_value = {
"__type": str(type(native_value)),
"isoformat": native_value.isoformat(),
}
return {
"native_value": native_value,
"native_unit_of_measurement": self.native_unit_of_measurement,
}
@classmethod
def from_dict(cls, restored: dict[str, Any]) -> SensorExtraStoredData | None:
"""Initialize a stored sensor state from a dict."""
try:
native_value = restored["native_value"]
native_unit_of_measurement = restored["native_unit_of_measurement"]
except KeyError:
return None
try:
type_ = native_value["__type"]
if type_ == "<class 'datetime.datetime'>":
native_value = dt_util.parse_datetime(native_value["isoformat"])
elif type_ == "<class 'datetime.date'>":
native_value = dt_util.parse_date(native_value["isoformat"])
except TypeError:
# native_value is not a dict
pass
except KeyError:
# native_value is a dict, but does not have all values
return None
return cls(native_value, native_unit_of_measurement)
class RestoreSensor(SensorEntity, RestoreEntity):
"""Mixin class for restoring previous sensor state."""
@property
def extra_restore_state_data(self) -> SensorExtraStoredData:
"""Return sensor specific state data to be restored."""
return SensorExtraStoredData(self.native_value, self.native_unit_of_measurement)
async def async_get_last_sensor_data(self) -> SensorExtraStoredData | None:
"""Restore native_value and native_unit_of_measurement."""
if (restored_last_extra_data := await self.async_get_last_extra_data()) is None:
return None
return SensorExtraStoredData.from_dict(restored_last_extra_data.as_dict())

View File

@ -1,6 +1,7 @@
"""Support for restoring entity states on startup."""
from __future__ import annotations
from abc import abstractmethod
import asyncio
from datetime import datetime, timedelta
import logging
@ -34,27 +35,65 @@ STATE_EXPIRATION = timedelta(days=7)
_StoredStateT = TypeVar("_StoredStateT", bound="StoredState")
class ExtraStoredData:
"""Object to hold extra stored data."""
@abstractmethod
def as_dict(self) -> dict[str, Any]:
"""Return a dict representation of the extra data.
Must be serializable by Home Assistant's JSONEncoder.
"""
class RestoredExtraData(ExtraStoredData):
"""Object to hold extra stored data loaded from storage."""
def __init__(self, json_dict: dict[str, Any]) -> None:
"""Object to hold extra stored data."""
self.json_dict = json_dict
def as_dict(self) -> dict[str, Any]:
"""Return a dict representation of the extra data."""
return self.json_dict
class StoredState:
"""Object to represent a stored state."""
def __init__(self, state: State, last_seen: datetime) -> None:
def __init__(
self,
state: State,
extra_data: ExtraStoredData | None,
last_seen: datetime,
) -> None:
"""Initialize a new stored state."""
self.state = state
self.extra_data = extra_data
self.last_seen = last_seen
self.state = state
def as_dict(self) -> dict[str, Any]:
"""Return a dict representation of the stored state."""
return {"state": self.state.as_dict(), "last_seen": self.last_seen}
result = {
"state": self.state.as_dict(),
"extra_data": self.extra_data.as_dict() if self.extra_data else None,
"last_seen": self.last_seen,
}
return result
@classmethod
def from_dict(cls: type[_StoredStateT], json_dict: dict) -> _StoredStateT:
"""Initialize a stored state from a dict."""
extra_data_dict = json_dict.get("extra_data")
extra_data = RestoredExtraData(extra_data_dict) if extra_data_dict else None
last_seen = json_dict["last_seen"]
if isinstance(last_seen, str):
last_seen = dt_util.parse_datetime(last_seen)
return cls(cast(State, State.from_dict(json_dict["state"])), last_seen)
return cls(
cast(State, State.from_dict(json_dict["state"])), extra_data, last_seen
)
class RestoreStateData:
@ -104,7 +143,7 @@ class RestoreStateData:
hass, STORAGE_VERSION, STORAGE_KEY, encoder=JSONEncoder
)
self.last_states: dict[str, StoredState] = {}
self.entity_ids: set[str] = set()
self.entities: dict[str, RestoreEntity] = {}
@callback
def async_get_stored_states(self) -> list[StoredState]:
@ -125,9 +164,11 @@ class RestoreStateData:
# Start with the currently registered states
stored_states = [
StoredState(state, now)
StoredState(
state, self.entities[state.entity_id].extra_restore_state_data, now
)
for state in all_states
if state.entity_id in self.entity_ids and
if state.entity_id in self.entities and
# Ignore all states that are entity registry placeholders
not state.attributes.get(ATTR_RESTORED)
]
@ -188,12 +229,14 @@ class RestoreStateData:
)
@callback
def async_restore_entity_added(self, entity_id: str) -> None:
def async_restore_entity_added(self, entity: RestoreEntity) -> None:
"""Store this entity's state when hass is shutdown."""
self.entity_ids.add(entity_id)
self.entities[entity.entity_id] = entity
@callback
def async_restore_entity_removed(self, entity_id: str) -> None:
def async_restore_entity_removed(
self, entity_id: str, extra_data: ExtraStoredData | None
) -> None:
"""Unregister this entity from saving state."""
# When an entity is being removed from hass, store its last state. This
# allows us to support state restoration if the entity is removed, then
@ -204,9 +247,11 @@ class RestoreStateData:
if state is not None:
state = State.from_dict(_encode_complex(state.as_dict()))
if state is not None:
self.last_states[entity_id] = StoredState(state, dt_util.utcnow())
self.last_states[entity_id] = StoredState(
state, extra_data, dt_util.utcnow()
)
self.entity_ids.remove(entity_id)
self.entities.pop(entity_id)
def _encode(value: Any) -> Any:
@ -244,7 +289,7 @@ class RestoreEntity(Entity):
super().async_internal_added_to_hass(),
RestoreStateData.async_get_instance(self.hass),
)
data.async_restore_entity_added(self.entity_id)
data.async_restore_entity_added(self)
async def async_internal_will_remove_from_hass(self) -> None:
"""Run when entity will be removed from hass."""
@ -252,10 +297,10 @@ class RestoreEntity(Entity):
super().async_internal_will_remove_from_hass(),
RestoreStateData.async_get_instance(self.hass),
)
data.async_restore_entity_removed(self.entity_id)
data.async_restore_entity_removed(self.entity_id, self.extra_restore_state_data)
async def async_get_last_state(self) -> State | None:
"""Get the entity state from the previous run."""
async def _async_get_restored_data(self) -> StoredState | None:
"""Get data stored for an entity, if any."""
if self.hass is None or self.entity_id is None:
# Return None if this entity isn't added to hass yet
_LOGGER.warning("Cannot get last state. Entity not added to hass") # type: ignore[unreachable]
@ -265,4 +310,24 @@ class RestoreEntity(Entity):
)
if self.entity_id not in data.last_states:
return None
return data.last_states[self.entity_id].state
return data.last_states[self.entity_id]
async def async_get_last_state(self) -> State | None:
"""Get the entity state from the previous run."""
if (stored_state := await self._async_get_restored_data()) is None:
return None
return stored_state.state
async def async_get_last_extra_data(self) -> ExtraStoredData | None:
"""Get the entity specific state data from the previous run."""
if (stored_state := await self._async_get_restored_data()) is None:
return None
return stored_state.extra_data
@property
def extra_restore_state_data(self) -> ExtraStoredData | None:
"""Return entity specific state data to be restored.
Implemented by platform classes.
"""
return None

View File

@ -44,7 +44,7 @@ from homeassistant.const import (
STATE_OFF,
STATE_ON,
)
from homeassistant.core import BLOCK_LOG_TIMEOUT, HomeAssistant, State
from homeassistant.core import BLOCK_LOG_TIMEOUT, HomeAssistant
from homeassistant.helpers import (
area_registry,
device_registry,
@ -937,8 +937,33 @@ def mock_restore_cache(hass, states):
json.dumps(restored_state["attributes"], cls=JSONEncoder)
),
}
last_states[state.entity_id] = restore_state.StoredState(
State.from_dict(restored_state), now
last_states[state.entity_id] = restore_state.StoredState.from_dict(
{"state": restored_state, "last_seen": now}
)
data.last_states = last_states
_LOGGER.debug("Restore cache: %s", data.last_states)
assert len(data.last_states) == len(states), f"Duplicate entity_id? {states}"
hass.data[key] = data
def mock_restore_cache_with_extra_data(hass, states):
"""Mock the DATA_RESTORE_CACHE."""
key = restore_state.DATA_RESTORE_STATE_TASK
data = restore_state.RestoreStateData(hass)
now = date_util.utcnow()
last_states = {}
for state, extra_data in states:
restored_state = state.as_dict()
restored_state = {
**restored_state,
"attributes": json.loads(
json.dumps(restored_state["attributes"], cls=JSONEncoder)
),
}
last_states[state.entity_id] = restore_state.StoredState.from_dict(
{"state": restored_state, "extra_data": extra_data, "last_seen": now}
)
data.last_states = last_states
_LOGGER.debug("Restore cache: %s", data.last_states)

View File

@ -11,10 +11,14 @@ from homeassistant.const import (
TEMP_CELSIUS,
TEMP_FAHRENHEIT,
)
from homeassistant.core import State
from homeassistant.helpers.restore_state import STORAGE_KEY as RESTORE_STATE_KEY
from homeassistant.setup import async_setup_component
from homeassistant.util import dt as dt_util
from homeassistant.util.unit_system import IMPERIAL_SYSTEM, METRIC_SYSTEM
from tests.common import mock_restore_cache_with_extra_data
@pytest.mark.parametrize(
"unit_system,native_unit,state_unit,native_value,state_value",
@ -210,3 +214,131 @@ async def test_reject_timezoneless_datetime_str(
"Invalid datetime: sensor.test provides state '2017-12-19 18:29:42', "
"which is missing timezone information"
) in caplog.text
RESTORE_DATA = {
"str": {"native_unit_of_measurement": "°F", "native_value": "abc123"},
"int": {"native_unit_of_measurement": "°F", "native_value": 123},
"float": {"native_unit_of_measurement": "°F", "native_value": 123.0},
"date": {
"native_unit_of_measurement": "°F",
"native_value": {
"__type": "<class 'datetime.date'>",
"isoformat": date(2020, 2, 8).isoformat(),
},
},
"datetime": {
"native_unit_of_measurement": "°F",
"native_value": {
"__type": "<class 'datetime.datetime'>",
"isoformat": datetime(2020, 2, 8, 15, tzinfo=timezone.utc).isoformat(),
},
},
}
# None | str | int | float | date | datetime:
@pytest.mark.parametrize(
"native_value, native_value_type, expected_extra_data, device_class",
[
("abc123", str, RESTORE_DATA["str"], None),
(123, int, RESTORE_DATA["int"], SensorDeviceClass.TEMPERATURE),
(123.0, float, RESTORE_DATA["float"], SensorDeviceClass.TEMPERATURE),
(date(2020, 2, 8), dict, RESTORE_DATA["date"], SensorDeviceClass.DATE),
(
datetime(2020, 2, 8, 15, tzinfo=timezone.utc),
dict,
RESTORE_DATA["datetime"],
SensorDeviceClass.TIMESTAMP,
),
],
)
async def test_restore_sensor_save_state(
hass,
enable_custom_integrations,
hass_storage,
native_value,
native_value_type,
expected_extra_data,
device_class,
):
"""Test RestoreSensor."""
platform = getattr(hass.components, "test.sensor")
platform.init(empty=True)
platform.ENTITIES["0"] = platform.MockRestoreSensor(
name="Test",
native_value=native_value,
native_unit_of_measurement=TEMP_FAHRENHEIT,
device_class=device_class,
)
entity0 = platform.ENTITIES["0"]
assert await async_setup_component(hass, "sensor", {"sensor": {"platform": "test"}})
await hass.async_block_till_done()
# Trigger saving state
await hass.async_stop()
assert len(hass_storage[RESTORE_STATE_KEY]["data"]) == 1
state = hass_storage[RESTORE_STATE_KEY]["data"][0]["state"]
assert state["entity_id"] == entity0.entity_id
extra_data = hass_storage[RESTORE_STATE_KEY]["data"][0]["extra_data"]
assert extra_data == expected_extra_data
assert type(extra_data["native_value"]) == native_value_type
@pytest.mark.parametrize(
"native_value, native_value_type, extra_data, device_class, uom",
[
("abc123", str, RESTORE_DATA["str"], None, "°F"),
(123, int, RESTORE_DATA["int"], SensorDeviceClass.TEMPERATURE, "°F"),
(123.0, float, RESTORE_DATA["float"], SensorDeviceClass.TEMPERATURE, "°F"),
(date(2020, 2, 8), date, RESTORE_DATA["date"], SensorDeviceClass.DATE, "°F"),
(
datetime(2020, 2, 8, 15, tzinfo=timezone.utc),
datetime,
RESTORE_DATA["datetime"],
SensorDeviceClass.TIMESTAMP,
"°F",
),
(None, type(None), None, None, None),
(None, type(None), {}, None, None),
(None, type(None), {"beer": 123}, None, None),
(
None,
type(None),
{"native_unit_of_measurement": "°F", "native_value": {}},
None,
None,
),
],
)
async def test_restore_sensor_restore_state(
hass,
enable_custom_integrations,
hass_storage,
native_value,
native_value_type,
extra_data,
device_class,
uom,
):
"""Test RestoreSensor."""
mock_restore_cache_with_extra_data(hass, ((State("sensor.test", ""), extra_data),))
platform = getattr(hass.components, "test.sensor")
platform.init(empty=True)
platform.ENTITIES["0"] = platform.MockRestoreSensor(
name="Test",
device_class=device_class,
)
entity0 = platform.ENTITIES["0"]
assert await async_setup_component(hass, "sensor", {"sensor": {"platform": "test"}})
await hass.async_block_till_done()
assert hass.states.get(entity0.entity_id)
assert entity0.native_value == native_value
assert type(entity0.native_value) == native_value_type
assert entity0.native_unit_of_measurement == uom

View File

@ -22,9 +22,9 @@ async def test_caching_data(hass):
"""Test that we cache data."""
now = dt_util.utcnow()
stored_states = [
StoredState(State("input_boolean.b0", "on"), now),
StoredState(State("input_boolean.b1", "on"), now),
StoredState(State("input_boolean.b2", "on"), now),
StoredState(State("input_boolean.b0", "on"), None, now),
StoredState(State("input_boolean.b1", "on"), None, now),
StoredState(State("input_boolean.b2", "on"), None, now),
]
data = await RestoreStateData.async_get_instance(hass)
@ -160,9 +160,9 @@ async def test_hass_starting(hass):
now = dt_util.utcnow()
stored_states = [
StoredState(State("input_boolean.b0", "on"), now),
StoredState(State("input_boolean.b1", "on"), now),
StoredState(State("input_boolean.b2", "on"), now),
StoredState(State("input_boolean.b0", "on"), None, now),
StoredState(State("input_boolean.b1", "on"), None, now),
StoredState(State("input_boolean.b2", "on"), None, now),
]
data = await RestoreStateData.async_get_instance(hass)
@ -225,15 +225,16 @@ async def test_dump_data(hass):
data = await RestoreStateData.async_get_instance(hass)
now = dt_util.utcnow()
data.last_states = {
"input_boolean.b0": StoredState(State("input_boolean.b0", "off"), now),
"input_boolean.b1": StoredState(State("input_boolean.b1", "off"), now),
"input_boolean.b2": StoredState(State("input_boolean.b2", "off"), now),
"input_boolean.b3": StoredState(State("input_boolean.b3", "off"), now),
"input_boolean.b0": StoredState(State("input_boolean.b0", "off"), None, now),
"input_boolean.b1": StoredState(State("input_boolean.b1", "off"), None, now),
"input_boolean.b2": StoredState(State("input_boolean.b2", "off"), None, now),
"input_boolean.b3": StoredState(State("input_boolean.b3", "off"), None, now),
"input_boolean.b4": StoredState(
State("input_boolean.b4", "off"),
None,
datetime(1985, 10, 26, 1, 22, tzinfo=dt_util.UTC),
),
"input_boolean.b5": StoredState(State("input_boolean.b5", "off"), now),
"input_boolean.b5": StoredState(State("input_boolean.b5", "off"), None, now),
}
with patch(

View File

@ -5,6 +5,7 @@ Call init before using it in your tests to ensure clean test data.
"""
from homeassistant.components.sensor import (
DEVICE_CLASSES,
RestoreSensor,
SensorDeviceClass,
SensorEntity,
)
@ -109,3 +110,17 @@ class MockSensor(MockEntity, SensorEntity):
def state_class(self):
"""Return the state class of this sensor."""
return self._handle("state_class")
class MockRestoreSensor(MockSensor, RestoreSensor):
"""Mock RestoreSensor class."""
async def async_added_to_hass(self) -> None:
"""Restore native_value and native_unit_of_measurement."""
await super().async_added_to_hass()
if (last_sensor_data := await self.async_get_last_sensor_data()) is None:
return
self._values["native_value"] = last_sensor_data.native_value
self._values[
"native_unit_of_measurement"
] = last_sensor_data.native_unit_of_measurement