mirror of https://github.com/home-assistant/core
Support restoring SensorEntity native_value (#66068)
This commit is contained in:
parent
f8a84f0101
commit
009b31941a
|
@ -52,7 +52,9 @@ from homeassistant.helpers.config_validation import ( # noqa: F401
|
|||
)
|
||||
from homeassistant.helpers.entity import Entity, EntityDescription
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.restore_state import ExtraStoredData, RestoreEntity
|
||||
from homeassistant.helpers.typing import ConfigType, StateType
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from .const import CONF_STATE_CLASS # noqa: F401
|
||||
|
||||
|
@ -447,3 +449,62 @@ class SensorEntity(Entity):
|
|||
return f"<Entity {self.name}>"
|
||||
|
||||
return super().__repr__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class SensorExtraStoredData(ExtraStoredData):
|
||||
"""Object to hold extra stored data."""
|
||||
|
||||
native_value: StateType | date | datetime
|
||||
native_unit_of_measurement: str | None
|
||||
|
||||
def as_dict(self) -> dict[str, Any]:
|
||||
"""Return a dict representation of the sensor data."""
|
||||
native_value: StateType | date | datetime | dict[str, str] = self.native_value
|
||||
if isinstance(native_value, (date, datetime)):
|
||||
native_value = {
|
||||
"__type": str(type(native_value)),
|
||||
"isoformat": native_value.isoformat(),
|
||||
}
|
||||
return {
|
||||
"native_value": native_value,
|
||||
"native_unit_of_measurement": self.native_unit_of_measurement,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, restored: dict[str, Any]) -> SensorExtraStoredData | None:
|
||||
"""Initialize a stored sensor state from a dict."""
|
||||
try:
|
||||
native_value = restored["native_value"]
|
||||
native_unit_of_measurement = restored["native_unit_of_measurement"]
|
||||
except KeyError:
|
||||
return None
|
||||
try:
|
||||
type_ = native_value["__type"]
|
||||
if type_ == "<class 'datetime.datetime'>":
|
||||
native_value = dt_util.parse_datetime(native_value["isoformat"])
|
||||
elif type_ == "<class 'datetime.date'>":
|
||||
native_value = dt_util.parse_date(native_value["isoformat"])
|
||||
except TypeError:
|
||||
# native_value is not a dict
|
||||
pass
|
||||
except KeyError:
|
||||
# native_value is a dict, but does not have all values
|
||||
return None
|
||||
|
||||
return cls(native_value, native_unit_of_measurement)
|
||||
|
||||
|
||||
class RestoreSensor(SensorEntity, RestoreEntity):
|
||||
"""Mixin class for restoring previous sensor state."""
|
||||
|
||||
@property
|
||||
def extra_restore_state_data(self) -> SensorExtraStoredData:
|
||||
"""Return sensor specific state data to be restored."""
|
||||
return SensorExtraStoredData(self.native_value, self.native_unit_of_measurement)
|
||||
|
||||
async def async_get_last_sensor_data(self) -> SensorExtraStoredData | None:
|
||||
"""Restore native_value and native_unit_of_measurement."""
|
||||
if (restored_last_extra_data := await self.async_get_last_extra_data()) is None:
|
||||
return None
|
||||
return SensorExtraStoredData.from_dict(restored_last_extra_data.as_dict())
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""Support for restoring entity states on startup."""
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
|
@ -34,27 +35,65 @@ STATE_EXPIRATION = timedelta(days=7)
|
|||
_StoredStateT = TypeVar("_StoredStateT", bound="StoredState")
|
||||
|
||||
|
||||
class ExtraStoredData:
|
||||
"""Object to hold extra stored data."""
|
||||
|
||||
@abstractmethod
|
||||
def as_dict(self) -> dict[str, Any]:
|
||||
"""Return a dict representation of the extra data.
|
||||
|
||||
Must be serializable by Home Assistant's JSONEncoder.
|
||||
"""
|
||||
|
||||
|
||||
class RestoredExtraData(ExtraStoredData):
|
||||
"""Object to hold extra stored data loaded from storage."""
|
||||
|
||||
def __init__(self, json_dict: dict[str, Any]) -> None:
|
||||
"""Object to hold extra stored data."""
|
||||
self.json_dict = json_dict
|
||||
|
||||
def as_dict(self) -> dict[str, Any]:
|
||||
"""Return a dict representation of the extra data."""
|
||||
return self.json_dict
|
||||
|
||||
|
||||
class StoredState:
|
||||
"""Object to represent a stored state."""
|
||||
|
||||
def __init__(self, state: State, last_seen: datetime) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
state: State,
|
||||
extra_data: ExtraStoredData | None,
|
||||
last_seen: datetime,
|
||||
) -> None:
|
||||
"""Initialize a new stored state."""
|
||||
self.state = state
|
||||
self.extra_data = extra_data
|
||||
self.last_seen = last_seen
|
||||
self.state = state
|
||||
|
||||
def as_dict(self) -> dict[str, Any]:
|
||||
"""Return a dict representation of the stored state."""
|
||||
return {"state": self.state.as_dict(), "last_seen": self.last_seen}
|
||||
result = {
|
||||
"state": self.state.as_dict(),
|
||||
"extra_data": self.extra_data.as_dict() if self.extra_data else None,
|
||||
"last_seen": self.last_seen,
|
||||
}
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls: type[_StoredStateT], json_dict: dict) -> _StoredStateT:
|
||||
"""Initialize a stored state from a dict."""
|
||||
extra_data_dict = json_dict.get("extra_data")
|
||||
extra_data = RestoredExtraData(extra_data_dict) if extra_data_dict else None
|
||||
last_seen = json_dict["last_seen"]
|
||||
|
||||
if isinstance(last_seen, str):
|
||||
last_seen = dt_util.parse_datetime(last_seen)
|
||||
|
||||
return cls(cast(State, State.from_dict(json_dict["state"])), last_seen)
|
||||
return cls(
|
||||
cast(State, State.from_dict(json_dict["state"])), extra_data, last_seen
|
||||
)
|
||||
|
||||
|
||||
class RestoreStateData:
|
||||
|
@ -104,7 +143,7 @@ class RestoreStateData:
|
|||
hass, STORAGE_VERSION, STORAGE_KEY, encoder=JSONEncoder
|
||||
)
|
||||
self.last_states: dict[str, StoredState] = {}
|
||||
self.entity_ids: set[str] = set()
|
||||
self.entities: dict[str, RestoreEntity] = {}
|
||||
|
||||
@callback
|
||||
def async_get_stored_states(self) -> list[StoredState]:
|
||||
|
@ -125,9 +164,11 @@ class RestoreStateData:
|
|||
|
||||
# Start with the currently registered states
|
||||
stored_states = [
|
||||
StoredState(state, now)
|
||||
StoredState(
|
||||
state, self.entities[state.entity_id].extra_restore_state_data, now
|
||||
)
|
||||
for state in all_states
|
||||
if state.entity_id in self.entity_ids and
|
||||
if state.entity_id in self.entities and
|
||||
# Ignore all states that are entity registry placeholders
|
||||
not state.attributes.get(ATTR_RESTORED)
|
||||
]
|
||||
|
@ -188,12 +229,14 @@ class RestoreStateData:
|
|||
)
|
||||
|
||||
@callback
|
||||
def async_restore_entity_added(self, entity_id: str) -> None:
|
||||
def async_restore_entity_added(self, entity: RestoreEntity) -> None:
|
||||
"""Store this entity's state when hass is shutdown."""
|
||||
self.entity_ids.add(entity_id)
|
||||
self.entities[entity.entity_id] = entity
|
||||
|
||||
@callback
|
||||
def async_restore_entity_removed(self, entity_id: str) -> None:
|
||||
def async_restore_entity_removed(
|
||||
self, entity_id: str, extra_data: ExtraStoredData | None
|
||||
) -> None:
|
||||
"""Unregister this entity from saving state."""
|
||||
# When an entity is being removed from hass, store its last state. This
|
||||
# allows us to support state restoration if the entity is removed, then
|
||||
|
@ -204,9 +247,11 @@ class RestoreStateData:
|
|||
if state is not None:
|
||||
state = State.from_dict(_encode_complex(state.as_dict()))
|
||||
if state is not None:
|
||||
self.last_states[entity_id] = StoredState(state, dt_util.utcnow())
|
||||
self.last_states[entity_id] = StoredState(
|
||||
state, extra_data, dt_util.utcnow()
|
||||
)
|
||||
|
||||
self.entity_ids.remove(entity_id)
|
||||
self.entities.pop(entity_id)
|
||||
|
||||
|
||||
def _encode(value: Any) -> Any:
|
||||
|
@ -244,7 +289,7 @@ class RestoreEntity(Entity):
|
|||
super().async_internal_added_to_hass(),
|
||||
RestoreStateData.async_get_instance(self.hass),
|
||||
)
|
||||
data.async_restore_entity_added(self.entity_id)
|
||||
data.async_restore_entity_added(self)
|
||||
|
||||
async def async_internal_will_remove_from_hass(self) -> None:
|
||||
"""Run when entity will be removed from hass."""
|
||||
|
@ -252,10 +297,10 @@ class RestoreEntity(Entity):
|
|||
super().async_internal_will_remove_from_hass(),
|
||||
RestoreStateData.async_get_instance(self.hass),
|
||||
)
|
||||
data.async_restore_entity_removed(self.entity_id)
|
||||
data.async_restore_entity_removed(self.entity_id, self.extra_restore_state_data)
|
||||
|
||||
async def async_get_last_state(self) -> State | None:
|
||||
"""Get the entity state from the previous run."""
|
||||
async def _async_get_restored_data(self) -> StoredState | None:
|
||||
"""Get data stored for an entity, if any."""
|
||||
if self.hass is None or self.entity_id is None:
|
||||
# Return None if this entity isn't added to hass yet
|
||||
_LOGGER.warning("Cannot get last state. Entity not added to hass") # type: ignore[unreachable]
|
||||
|
@ -265,4 +310,24 @@ class RestoreEntity(Entity):
|
|||
)
|
||||
if self.entity_id not in data.last_states:
|
||||
return None
|
||||
return data.last_states[self.entity_id].state
|
||||
return data.last_states[self.entity_id]
|
||||
|
||||
async def async_get_last_state(self) -> State | None:
|
||||
"""Get the entity state from the previous run."""
|
||||
if (stored_state := await self._async_get_restored_data()) is None:
|
||||
return None
|
||||
return stored_state.state
|
||||
|
||||
async def async_get_last_extra_data(self) -> ExtraStoredData | None:
|
||||
"""Get the entity specific state data from the previous run."""
|
||||
if (stored_state := await self._async_get_restored_data()) is None:
|
||||
return None
|
||||
return stored_state.extra_data
|
||||
|
||||
@property
|
||||
def extra_restore_state_data(self) -> ExtraStoredData | None:
|
||||
"""Return entity specific state data to be restored.
|
||||
|
||||
Implemented by platform classes.
|
||||
"""
|
||||
return None
|
||||
|
|
|
@ -44,7 +44,7 @@ from homeassistant.const import (
|
|||
STATE_OFF,
|
||||
STATE_ON,
|
||||
)
|
||||
from homeassistant.core import BLOCK_LOG_TIMEOUT, HomeAssistant, State
|
||||
from homeassistant.core import BLOCK_LOG_TIMEOUT, HomeAssistant
|
||||
from homeassistant.helpers import (
|
||||
area_registry,
|
||||
device_registry,
|
||||
|
@ -937,8 +937,33 @@ def mock_restore_cache(hass, states):
|
|||
json.dumps(restored_state["attributes"], cls=JSONEncoder)
|
||||
),
|
||||
}
|
||||
last_states[state.entity_id] = restore_state.StoredState(
|
||||
State.from_dict(restored_state), now
|
||||
last_states[state.entity_id] = restore_state.StoredState.from_dict(
|
||||
{"state": restored_state, "last_seen": now}
|
||||
)
|
||||
data.last_states = last_states
|
||||
_LOGGER.debug("Restore cache: %s", data.last_states)
|
||||
assert len(data.last_states) == len(states), f"Duplicate entity_id? {states}"
|
||||
|
||||
hass.data[key] = data
|
||||
|
||||
|
||||
def mock_restore_cache_with_extra_data(hass, states):
|
||||
"""Mock the DATA_RESTORE_CACHE."""
|
||||
key = restore_state.DATA_RESTORE_STATE_TASK
|
||||
data = restore_state.RestoreStateData(hass)
|
||||
now = date_util.utcnow()
|
||||
|
||||
last_states = {}
|
||||
for state, extra_data in states:
|
||||
restored_state = state.as_dict()
|
||||
restored_state = {
|
||||
**restored_state,
|
||||
"attributes": json.loads(
|
||||
json.dumps(restored_state["attributes"], cls=JSONEncoder)
|
||||
),
|
||||
}
|
||||
last_states[state.entity_id] = restore_state.StoredState.from_dict(
|
||||
{"state": restored_state, "extra_data": extra_data, "last_seen": now}
|
||||
)
|
||||
data.last_states = last_states
|
||||
_LOGGER.debug("Restore cache: %s", data.last_states)
|
||||
|
|
|
@ -11,10 +11,14 @@ from homeassistant.const import (
|
|||
TEMP_CELSIUS,
|
||||
TEMP_FAHRENHEIT,
|
||||
)
|
||||
from homeassistant.core import State
|
||||
from homeassistant.helpers.restore_state import STORAGE_KEY as RESTORE_STATE_KEY
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.util import dt as dt_util
|
||||
from homeassistant.util.unit_system import IMPERIAL_SYSTEM, METRIC_SYSTEM
|
||||
|
||||
from tests.common import mock_restore_cache_with_extra_data
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"unit_system,native_unit,state_unit,native_value,state_value",
|
||||
|
@ -210,3 +214,131 @@ async def test_reject_timezoneless_datetime_str(
|
|||
"Invalid datetime: sensor.test provides state '2017-12-19 18:29:42', "
|
||||
"which is missing timezone information"
|
||||
) in caplog.text
|
||||
|
||||
|
||||
RESTORE_DATA = {
|
||||
"str": {"native_unit_of_measurement": "°F", "native_value": "abc123"},
|
||||
"int": {"native_unit_of_measurement": "°F", "native_value": 123},
|
||||
"float": {"native_unit_of_measurement": "°F", "native_value": 123.0},
|
||||
"date": {
|
||||
"native_unit_of_measurement": "°F",
|
||||
"native_value": {
|
||||
"__type": "<class 'datetime.date'>",
|
||||
"isoformat": date(2020, 2, 8).isoformat(),
|
||||
},
|
||||
},
|
||||
"datetime": {
|
||||
"native_unit_of_measurement": "°F",
|
||||
"native_value": {
|
||||
"__type": "<class 'datetime.datetime'>",
|
||||
"isoformat": datetime(2020, 2, 8, 15, tzinfo=timezone.utc).isoformat(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# None | str | int | float | date | datetime:
|
||||
@pytest.mark.parametrize(
|
||||
"native_value, native_value_type, expected_extra_data, device_class",
|
||||
[
|
||||
("abc123", str, RESTORE_DATA["str"], None),
|
||||
(123, int, RESTORE_DATA["int"], SensorDeviceClass.TEMPERATURE),
|
||||
(123.0, float, RESTORE_DATA["float"], SensorDeviceClass.TEMPERATURE),
|
||||
(date(2020, 2, 8), dict, RESTORE_DATA["date"], SensorDeviceClass.DATE),
|
||||
(
|
||||
datetime(2020, 2, 8, 15, tzinfo=timezone.utc),
|
||||
dict,
|
||||
RESTORE_DATA["datetime"],
|
||||
SensorDeviceClass.TIMESTAMP,
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_restore_sensor_save_state(
|
||||
hass,
|
||||
enable_custom_integrations,
|
||||
hass_storage,
|
||||
native_value,
|
||||
native_value_type,
|
||||
expected_extra_data,
|
||||
device_class,
|
||||
):
|
||||
"""Test RestoreSensor."""
|
||||
platform = getattr(hass.components, "test.sensor")
|
||||
platform.init(empty=True)
|
||||
platform.ENTITIES["0"] = platform.MockRestoreSensor(
|
||||
name="Test",
|
||||
native_value=native_value,
|
||||
native_unit_of_measurement=TEMP_FAHRENHEIT,
|
||||
device_class=device_class,
|
||||
)
|
||||
|
||||
entity0 = platform.ENTITIES["0"]
|
||||
assert await async_setup_component(hass, "sensor", {"sensor": {"platform": "test"}})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Trigger saving state
|
||||
await hass.async_stop()
|
||||
|
||||
assert len(hass_storage[RESTORE_STATE_KEY]["data"]) == 1
|
||||
state = hass_storage[RESTORE_STATE_KEY]["data"][0]["state"]
|
||||
assert state["entity_id"] == entity0.entity_id
|
||||
extra_data = hass_storage[RESTORE_STATE_KEY]["data"][0]["extra_data"]
|
||||
assert extra_data == expected_extra_data
|
||||
assert type(extra_data["native_value"]) == native_value_type
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"native_value, native_value_type, extra_data, device_class, uom",
|
||||
[
|
||||
("abc123", str, RESTORE_DATA["str"], None, "°F"),
|
||||
(123, int, RESTORE_DATA["int"], SensorDeviceClass.TEMPERATURE, "°F"),
|
||||
(123.0, float, RESTORE_DATA["float"], SensorDeviceClass.TEMPERATURE, "°F"),
|
||||
(date(2020, 2, 8), date, RESTORE_DATA["date"], SensorDeviceClass.DATE, "°F"),
|
||||
(
|
||||
datetime(2020, 2, 8, 15, tzinfo=timezone.utc),
|
||||
datetime,
|
||||
RESTORE_DATA["datetime"],
|
||||
SensorDeviceClass.TIMESTAMP,
|
||||
"°F",
|
||||
),
|
||||
(None, type(None), None, None, None),
|
||||
(None, type(None), {}, None, None),
|
||||
(None, type(None), {"beer": 123}, None, None),
|
||||
(
|
||||
None,
|
||||
type(None),
|
||||
{"native_unit_of_measurement": "°F", "native_value": {}},
|
||||
None,
|
||||
None,
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_restore_sensor_restore_state(
|
||||
hass,
|
||||
enable_custom_integrations,
|
||||
hass_storage,
|
||||
native_value,
|
||||
native_value_type,
|
||||
extra_data,
|
||||
device_class,
|
||||
uom,
|
||||
):
|
||||
"""Test RestoreSensor."""
|
||||
mock_restore_cache_with_extra_data(hass, ((State("sensor.test", ""), extra_data),))
|
||||
|
||||
platform = getattr(hass.components, "test.sensor")
|
||||
platform.init(empty=True)
|
||||
platform.ENTITIES["0"] = platform.MockRestoreSensor(
|
||||
name="Test",
|
||||
device_class=device_class,
|
||||
)
|
||||
|
||||
entity0 = platform.ENTITIES["0"]
|
||||
assert await async_setup_component(hass, "sensor", {"sensor": {"platform": "test"}})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert hass.states.get(entity0.entity_id)
|
||||
|
||||
assert entity0.native_value == native_value
|
||||
assert type(entity0.native_value) == native_value_type
|
||||
assert entity0.native_unit_of_measurement == uom
|
||||
|
|
|
@ -22,9 +22,9 @@ async def test_caching_data(hass):
|
|||
"""Test that we cache data."""
|
||||
now = dt_util.utcnow()
|
||||
stored_states = [
|
||||
StoredState(State("input_boolean.b0", "on"), now),
|
||||
StoredState(State("input_boolean.b1", "on"), now),
|
||||
StoredState(State("input_boolean.b2", "on"), now),
|
||||
StoredState(State("input_boolean.b0", "on"), None, now),
|
||||
StoredState(State("input_boolean.b1", "on"), None, now),
|
||||
StoredState(State("input_boolean.b2", "on"), None, now),
|
||||
]
|
||||
|
||||
data = await RestoreStateData.async_get_instance(hass)
|
||||
|
@ -160,9 +160,9 @@ async def test_hass_starting(hass):
|
|||
|
||||
now = dt_util.utcnow()
|
||||
stored_states = [
|
||||
StoredState(State("input_boolean.b0", "on"), now),
|
||||
StoredState(State("input_boolean.b1", "on"), now),
|
||||
StoredState(State("input_boolean.b2", "on"), now),
|
||||
StoredState(State("input_boolean.b0", "on"), None, now),
|
||||
StoredState(State("input_boolean.b1", "on"), None, now),
|
||||
StoredState(State("input_boolean.b2", "on"), None, now),
|
||||
]
|
||||
|
||||
data = await RestoreStateData.async_get_instance(hass)
|
||||
|
@ -225,15 +225,16 @@ async def test_dump_data(hass):
|
|||
data = await RestoreStateData.async_get_instance(hass)
|
||||
now = dt_util.utcnow()
|
||||
data.last_states = {
|
||||
"input_boolean.b0": StoredState(State("input_boolean.b0", "off"), now),
|
||||
"input_boolean.b1": StoredState(State("input_boolean.b1", "off"), now),
|
||||
"input_boolean.b2": StoredState(State("input_boolean.b2", "off"), now),
|
||||
"input_boolean.b3": StoredState(State("input_boolean.b3", "off"), now),
|
||||
"input_boolean.b0": StoredState(State("input_boolean.b0", "off"), None, now),
|
||||
"input_boolean.b1": StoredState(State("input_boolean.b1", "off"), None, now),
|
||||
"input_boolean.b2": StoredState(State("input_boolean.b2", "off"), None, now),
|
||||
"input_boolean.b3": StoredState(State("input_boolean.b3", "off"), None, now),
|
||||
"input_boolean.b4": StoredState(
|
||||
State("input_boolean.b4", "off"),
|
||||
None,
|
||||
datetime(1985, 10, 26, 1, 22, tzinfo=dt_util.UTC),
|
||||
),
|
||||
"input_boolean.b5": StoredState(State("input_boolean.b5", "off"), now),
|
||||
"input_boolean.b5": StoredState(State("input_boolean.b5", "off"), None, now),
|
||||
}
|
||||
|
||||
with patch(
|
||||
|
|
|
@ -5,6 +5,7 @@ Call init before using it in your tests to ensure clean test data.
|
|||
"""
|
||||
from homeassistant.components.sensor import (
|
||||
DEVICE_CLASSES,
|
||||
RestoreSensor,
|
||||
SensorDeviceClass,
|
||||
SensorEntity,
|
||||
)
|
||||
|
@ -109,3 +110,17 @@ class MockSensor(MockEntity, SensorEntity):
|
|||
def state_class(self):
|
||||
"""Return the state class of this sensor."""
|
||||
return self._handle("state_class")
|
||||
|
||||
|
||||
class MockRestoreSensor(MockSensor, RestoreSensor):
|
||||
"""Mock RestoreSensor class."""
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Restore native_value and native_unit_of_measurement."""
|
||||
await super().async_added_to_hass()
|
||||
if (last_sensor_data := await self.async_get_last_sensor_data()) is None:
|
||||
return
|
||||
self._values["native_value"] = last_sensor_data.native_value
|
||||
self._values[
|
||||
"native_unit_of_measurement"
|
||||
] = last_sensor_data.native_unit_of_measurement
|
||||
|
|
Loading…
Reference in New Issue