diff --git a/homeassistant/components/sensor/__init__.py b/homeassistant/components/sensor/__init__.py index 38461bce859e..c6c4d18d21d8 100644 --- a/homeassistant/components/sensor/__init__.py +++ b/homeassistant/components/sensor/__init__.py @@ -52,7 +52,9 @@ from homeassistant.helpers.config_validation import ( # noqa: F401 ) from homeassistant.helpers.entity import Entity, EntityDescription from homeassistant.helpers.entity_component import EntityComponent +from homeassistant.helpers.restore_state import ExtraStoredData, RestoreEntity from homeassistant.helpers.typing import ConfigType, StateType +from homeassistant.util import dt as dt_util from .const import CONF_STATE_CLASS # noqa: F401 @@ -447,3 +449,62 @@ class SensorEntity(Entity): return f"" return super().__repr__() + + +@dataclass +class SensorExtraStoredData(ExtraStoredData): + """Object to hold extra stored data.""" + + native_value: StateType | date | datetime + native_unit_of_measurement: str | None + + def as_dict(self) -> dict[str, Any]: + """Return a dict representation of the sensor data.""" + native_value: StateType | date | datetime | dict[str, str] = self.native_value + if isinstance(native_value, (date, datetime)): + native_value = { + "__type": str(type(native_value)), + "isoformat": native_value.isoformat(), + } + return { + "native_value": native_value, + "native_unit_of_measurement": self.native_unit_of_measurement, + } + + @classmethod + def from_dict(cls, restored: dict[str, Any]) -> SensorExtraStoredData | None: + """Initialize a stored sensor state from a dict.""" + try: + native_value = restored["native_value"] + native_unit_of_measurement = restored["native_unit_of_measurement"] + except KeyError: + return None + try: + type_ = native_value["__type"] + if type_ == "": + native_value = dt_util.parse_datetime(native_value["isoformat"]) + elif type_ == "": + native_value = dt_util.parse_date(native_value["isoformat"]) + except TypeError: + # native_value is not a dict + pass + except KeyError: + # native_value is a dict, but does not have all values + return None + + return cls(native_value, native_unit_of_measurement) + + +class RestoreSensor(SensorEntity, RestoreEntity): + """Mixin class for restoring previous sensor state.""" + + @property + def extra_restore_state_data(self) -> SensorExtraStoredData: + """Return sensor specific state data to be restored.""" + return SensorExtraStoredData(self.native_value, self.native_unit_of_measurement) + + async def async_get_last_sensor_data(self) -> SensorExtraStoredData | None: + """Restore native_value and native_unit_of_measurement.""" + if (restored_last_extra_data := await self.async_get_last_extra_data()) is None: + return None + return SensorExtraStoredData.from_dict(restored_last_extra_data.as_dict()) diff --git a/homeassistant/helpers/restore_state.py b/homeassistant/helpers/restore_state.py index 4857210f125e..79d46f8ec2e6 100644 --- a/homeassistant/helpers/restore_state.py +++ b/homeassistant/helpers/restore_state.py @@ -1,6 +1,7 @@ """Support for restoring entity states on startup.""" from __future__ import annotations +from abc import abstractmethod import asyncio from datetime import datetime, timedelta import logging @@ -34,27 +35,65 @@ STATE_EXPIRATION = timedelta(days=7) _StoredStateT = TypeVar("_StoredStateT", bound="StoredState") +class ExtraStoredData: + """Object to hold extra stored data.""" + + @abstractmethod + def as_dict(self) -> dict[str, Any]: + """Return a dict representation of the extra data. + + Must be serializable by Home Assistant's JSONEncoder. + """ + + +class RestoredExtraData(ExtraStoredData): + """Object to hold extra stored data loaded from storage.""" + + def __init__(self, json_dict: dict[str, Any]) -> None: + """Object to hold extra stored data.""" + self.json_dict = json_dict + + def as_dict(self) -> dict[str, Any]: + """Return a dict representation of the extra data.""" + return self.json_dict + + class StoredState: """Object to represent a stored state.""" - def __init__(self, state: State, last_seen: datetime) -> None: + def __init__( + self, + state: State, + extra_data: ExtraStoredData | None, + last_seen: datetime, + ) -> None: """Initialize a new stored state.""" - self.state = state + self.extra_data = extra_data self.last_seen = last_seen + self.state = state def as_dict(self) -> dict[str, Any]: """Return a dict representation of the stored state.""" - return {"state": self.state.as_dict(), "last_seen": self.last_seen} + result = { + "state": self.state.as_dict(), + "extra_data": self.extra_data.as_dict() if self.extra_data else None, + "last_seen": self.last_seen, + } + return result @classmethod def from_dict(cls: type[_StoredStateT], json_dict: dict) -> _StoredStateT: """Initialize a stored state from a dict.""" + extra_data_dict = json_dict.get("extra_data") + extra_data = RestoredExtraData(extra_data_dict) if extra_data_dict else None last_seen = json_dict["last_seen"] if isinstance(last_seen, str): last_seen = dt_util.parse_datetime(last_seen) - return cls(cast(State, State.from_dict(json_dict["state"])), last_seen) + return cls( + cast(State, State.from_dict(json_dict["state"])), extra_data, last_seen + ) class RestoreStateData: @@ -104,7 +143,7 @@ class RestoreStateData: hass, STORAGE_VERSION, STORAGE_KEY, encoder=JSONEncoder ) self.last_states: dict[str, StoredState] = {} - self.entity_ids: set[str] = set() + self.entities: dict[str, RestoreEntity] = {} @callback def async_get_stored_states(self) -> list[StoredState]: @@ -125,9 +164,11 @@ class RestoreStateData: # Start with the currently registered states stored_states = [ - StoredState(state, now) + StoredState( + state, self.entities[state.entity_id].extra_restore_state_data, now + ) for state in all_states - if state.entity_id in self.entity_ids and + if state.entity_id in self.entities and # Ignore all states that are entity registry placeholders not state.attributes.get(ATTR_RESTORED) ] @@ -188,12 +229,14 @@ class RestoreStateData: ) @callback - def async_restore_entity_added(self, entity_id: str) -> None: + def async_restore_entity_added(self, entity: RestoreEntity) -> None: """Store this entity's state when hass is shutdown.""" - self.entity_ids.add(entity_id) + self.entities[entity.entity_id] = entity @callback - def async_restore_entity_removed(self, entity_id: str) -> None: + def async_restore_entity_removed( + self, entity_id: str, extra_data: ExtraStoredData | None + ) -> None: """Unregister this entity from saving state.""" # When an entity is being removed from hass, store its last state. This # allows us to support state restoration if the entity is removed, then @@ -204,9 +247,11 @@ class RestoreStateData: if state is not None: state = State.from_dict(_encode_complex(state.as_dict())) if state is not None: - self.last_states[entity_id] = StoredState(state, dt_util.utcnow()) + self.last_states[entity_id] = StoredState( + state, extra_data, dt_util.utcnow() + ) - self.entity_ids.remove(entity_id) + self.entities.pop(entity_id) def _encode(value: Any) -> Any: @@ -244,7 +289,7 @@ class RestoreEntity(Entity): super().async_internal_added_to_hass(), RestoreStateData.async_get_instance(self.hass), ) - data.async_restore_entity_added(self.entity_id) + data.async_restore_entity_added(self) async def async_internal_will_remove_from_hass(self) -> None: """Run when entity will be removed from hass.""" @@ -252,10 +297,10 @@ class RestoreEntity(Entity): super().async_internal_will_remove_from_hass(), RestoreStateData.async_get_instance(self.hass), ) - data.async_restore_entity_removed(self.entity_id) + data.async_restore_entity_removed(self.entity_id, self.extra_restore_state_data) - async def async_get_last_state(self) -> State | None: - """Get the entity state from the previous run.""" + async def _async_get_restored_data(self) -> StoredState | None: + """Get data stored for an entity, if any.""" if self.hass is None or self.entity_id is None: # Return None if this entity isn't added to hass yet _LOGGER.warning("Cannot get last state. Entity not added to hass") # type: ignore[unreachable] @@ -265,4 +310,24 @@ class RestoreEntity(Entity): ) if self.entity_id not in data.last_states: return None - return data.last_states[self.entity_id].state + return data.last_states[self.entity_id] + + async def async_get_last_state(self) -> State | None: + """Get the entity state from the previous run.""" + if (stored_state := await self._async_get_restored_data()) is None: + return None + return stored_state.state + + async def async_get_last_extra_data(self) -> ExtraStoredData | None: + """Get the entity specific state data from the previous run.""" + if (stored_state := await self._async_get_restored_data()) is None: + return None + return stored_state.extra_data + + @property + def extra_restore_state_data(self) -> ExtraStoredData | None: + """Return entity specific state data to be restored. + + Implemented by platform classes. + """ + return None diff --git a/tests/common.py b/tests/common.py index 3da0fcb98ddf..c8dfb3ed841d 100644 --- a/tests/common.py +++ b/tests/common.py @@ -44,7 +44,7 @@ from homeassistant.const import ( STATE_OFF, STATE_ON, ) -from homeassistant.core import BLOCK_LOG_TIMEOUT, HomeAssistant, State +from homeassistant.core import BLOCK_LOG_TIMEOUT, HomeAssistant from homeassistant.helpers import ( area_registry, device_registry, @@ -937,8 +937,33 @@ def mock_restore_cache(hass, states): json.dumps(restored_state["attributes"], cls=JSONEncoder) ), } - last_states[state.entity_id] = restore_state.StoredState( - State.from_dict(restored_state), now + last_states[state.entity_id] = restore_state.StoredState.from_dict( + {"state": restored_state, "last_seen": now} + ) + data.last_states = last_states + _LOGGER.debug("Restore cache: %s", data.last_states) + assert len(data.last_states) == len(states), f"Duplicate entity_id? {states}" + + hass.data[key] = data + + +def mock_restore_cache_with_extra_data(hass, states): + """Mock the DATA_RESTORE_CACHE.""" + key = restore_state.DATA_RESTORE_STATE_TASK + data = restore_state.RestoreStateData(hass) + now = date_util.utcnow() + + last_states = {} + for state, extra_data in states: + restored_state = state.as_dict() + restored_state = { + **restored_state, + "attributes": json.loads( + json.dumps(restored_state["attributes"], cls=JSONEncoder) + ), + } + last_states[state.entity_id] = restore_state.StoredState.from_dict( + {"state": restored_state, "extra_data": extra_data, "last_seen": now} ) data.last_states = last_states _LOGGER.debug("Restore cache: %s", data.last_states) diff --git a/tests/components/sensor/test_init.py b/tests/components/sensor/test_init.py index eed88d92d041..df33cb1a0815 100644 --- a/tests/components/sensor/test_init.py +++ b/tests/components/sensor/test_init.py @@ -11,10 +11,14 @@ from homeassistant.const import ( TEMP_CELSIUS, TEMP_FAHRENHEIT, ) +from homeassistant.core import State +from homeassistant.helpers.restore_state import STORAGE_KEY as RESTORE_STATE_KEY from homeassistant.setup import async_setup_component from homeassistant.util import dt as dt_util from homeassistant.util.unit_system import IMPERIAL_SYSTEM, METRIC_SYSTEM +from tests.common import mock_restore_cache_with_extra_data + @pytest.mark.parametrize( "unit_system,native_unit,state_unit,native_value,state_value", @@ -210,3 +214,131 @@ async def test_reject_timezoneless_datetime_str( "Invalid datetime: sensor.test provides state '2017-12-19 18:29:42', " "which is missing timezone information" ) in caplog.text + + +RESTORE_DATA = { + "str": {"native_unit_of_measurement": "°F", "native_value": "abc123"}, + "int": {"native_unit_of_measurement": "°F", "native_value": 123}, + "float": {"native_unit_of_measurement": "°F", "native_value": 123.0}, + "date": { + "native_unit_of_measurement": "°F", + "native_value": { + "__type": "", + "isoformat": date(2020, 2, 8).isoformat(), + }, + }, + "datetime": { + "native_unit_of_measurement": "°F", + "native_value": { + "__type": "", + "isoformat": datetime(2020, 2, 8, 15, tzinfo=timezone.utc).isoformat(), + }, + }, +} + + +# None | str | int | float | date | datetime: +@pytest.mark.parametrize( + "native_value, native_value_type, expected_extra_data, device_class", + [ + ("abc123", str, RESTORE_DATA["str"], None), + (123, int, RESTORE_DATA["int"], SensorDeviceClass.TEMPERATURE), + (123.0, float, RESTORE_DATA["float"], SensorDeviceClass.TEMPERATURE), + (date(2020, 2, 8), dict, RESTORE_DATA["date"], SensorDeviceClass.DATE), + ( + datetime(2020, 2, 8, 15, tzinfo=timezone.utc), + dict, + RESTORE_DATA["datetime"], + SensorDeviceClass.TIMESTAMP, + ), + ], +) +async def test_restore_sensor_save_state( + hass, + enable_custom_integrations, + hass_storage, + native_value, + native_value_type, + expected_extra_data, + device_class, +): + """Test RestoreSensor.""" + platform = getattr(hass.components, "test.sensor") + platform.init(empty=True) + platform.ENTITIES["0"] = platform.MockRestoreSensor( + name="Test", + native_value=native_value, + native_unit_of_measurement=TEMP_FAHRENHEIT, + device_class=device_class, + ) + + entity0 = platform.ENTITIES["0"] + assert await async_setup_component(hass, "sensor", {"sensor": {"platform": "test"}}) + await hass.async_block_till_done() + + # Trigger saving state + await hass.async_stop() + + assert len(hass_storage[RESTORE_STATE_KEY]["data"]) == 1 + state = hass_storage[RESTORE_STATE_KEY]["data"][0]["state"] + assert state["entity_id"] == entity0.entity_id + extra_data = hass_storage[RESTORE_STATE_KEY]["data"][0]["extra_data"] + assert extra_data == expected_extra_data + assert type(extra_data["native_value"]) == native_value_type + + +@pytest.mark.parametrize( + "native_value, native_value_type, extra_data, device_class, uom", + [ + ("abc123", str, RESTORE_DATA["str"], None, "°F"), + (123, int, RESTORE_DATA["int"], SensorDeviceClass.TEMPERATURE, "°F"), + (123.0, float, RESTORE_DATA["float"], SensorDeviceClass.TEMPERATURE, "°F"), + (date(2020, 2, 8), date, RESTORE_DATA["date"], SensorDeviceClass.DATE, "°F"), + ( + datetime(2020, 2, 8, 15, tzinfo=timezone.utc), + datetime, + RESTORE_DATA["datetime"], + SensorDeviceClass.TIMESTAMP, + "°F", + ), + (None, type(None), None, None, None), + (None, type(None), {}, None, None), + (None, type(None), {"beer": 123}, None, None), + ( + None, + type(None), + {"native_unit_of_measurement": "°F", "native_value": {}}, + None, + None, + ), + ], +) +async def test_restore_sensor_restore_state( + hass, + enable_custom_integrations, + hass_storage, + native_value, + native_value_type, + extra_data, + device_class, + uom, +): + """Test RestoreSensor.""" + mock_restore_cache_with_extra_data(hass, ((State("sensor.test", ""), extra_data),)) + + platform = getattr(hass.components, "test.sensor") + platform.init(empty=True) + platform.ENTITIES["0"] = platform.MockRestoreSensor( + name="Test", + device_class=device_class, + ) + + entity0 = platform.ENTITIES["0"] + assert await async_setup_component(hass, "sensor", {"sensor": {"platform": "test"}}) + await hass.async_block_till_done() + + assert hass.states.get(entity0.entity_id) + + assert entity0.native_value == native_value + assert type(entity0.native_value) == native_value_type + assert entity0.native_unit_of_measurement == uom diff --git a/tests/helpers/test_restore_state.py b/tests/helpers/test_restore_state.py index 79719b753261..efe951342fa2 100644 --- a/tests/helpers/test_restore_state.py +++ b/tests/helpers/test_restore_state.py @@ -22,9 +22,9 @@ async def test_caching_data(hass): """Test that we cache data.""" now = dt_util.utcnow() stored_states = [ - StoredState(State("input_boolean.b0", "on"), now), - StoredState(State("input_boolean.b1", "on"), now), - StoredState(State("input_boolean.b2", "on"), now), + StoredState(State("input_boolean.b0", "on"), None, now), + StoredState(State("input_boolean.b1", "on"), None, now), + StoredState(State("input_boolean.b2", "on"), None, now), ] data = await RestoreStateData.async_get_instance(hass) @@ -160,9 +160,9 @@ async def test_hass_starting(hass): now = dt_util.utcnow() stored_states = [ - StoredState(State("input_boolean.b0", "on"), now), - StoredState(State("input_boolean.b1", "on"), now), - StoredState(State("input_boolean.b2", "on"), now), + StoredState(State("input_boolean.b0", "on"), None, now), + StoredState(State("input_boolean.b1", "on"), None, now), + StoredState(State("input_boolean.b2", "on"), None, now), ] data = await RestoreStateData.async_get_instance(hass) @@ -225,15 +225,16 @@ async def test_dump_data(hass): data = await RestoreStateData.async_get_instance(hass) now = dt_util.utcnow() data.last_states = { - "input_boolean.b0": StoredState(State("input_boolean.b0", "off"), now), - "input_boolean.b1": StoredState(State("input_boolean.b1", "off"), now), - "input_boolean.b2": StoredState(State("input_boolean.b2", "off"), now), - "input_boolean.b3": StoredState(State("input_boolean.b3", "off"), now), + "input_boolean.b0": StoredState(State("input_boolean.b0", "off"), None, now), + "input_boolean.b1": StoredState(State("input_boolean.b1", "off"), None, now), + "input_boolean.b2": StoredState(State("input_boolean.b2", "off"), None, now), + "input_boolean.b3": StoredState(State("input_boolean.b3", "off"), None, now), "input_boolean.b4": StoredState( State("input_boolean.b4", "off"), + None, datetime(1985, 10, 26, 1, 22, tzinfo=dt_util.UTC), ), - "input_boolean.b5": StoredState(State("input_boolean.b5", "off"), now), + "input_boolean.b5": StoredState(State("input_boolean.b5", "off"), None, now), } with patch( diff --git a/tests/testing_config/custom_components/test/sensor.py b/tests/testing_config/custom_components/test/sensor.py index 4ad2580ad8bc..56587c80c348 100644 --- a/tests/testing_config/custom_components/test/sensor.py +++ b/tests/testing_config/custom_components/test/sensor.py @@ -5,6 +5,7 @@ Call init before using it in your tests to ensure clean test data. """ from homeassistant.components.sensor import ( DEVICE_CLASSES, + RestoreSensor, SensorDeviceClass, SensorEntity, ) @@ -109,3 +110,17 @@ class MockSensor(MockEntity, SensorEntity): def state_class(self): """Return the state class of this sensor.""" return self._handle("state_class") + + +class MockRestoreSensor(MockSensor, RestoreSensor): + """Mock RestoreSensor class.""" + + async def async_added_to_hass(self) -> None: + """Restore native_value and native_unit_of_measurement.""" + await super().async_added_to_hass() + if (last_sensor_data := await self.async_get_last_sensor_data()) is None: + return + self._values["native_value"] = last_sensor_data.native_value + self._values[ + "native_unit_of_measurement" + ] = last_sensor_data.native_unit_of_measurement