From 79b488981239b5f673ba12bc5d808c9c4c2973b9 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 24 Apr 2024 10:05:52 +0200 Subject: [PATCH] Always do thread safety checks when writing state for custom components (#116044) --- homeassistant/helpers/entity.py | 25 +++++++++++++++++++++---- tests/helpers/test_entity.py | 26 ++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index 40b145727a12..cf21882eec8a 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -521,6 +521,7 @@ class Entity( # While not purely typed, it makes typehinting more useful for us # and removes the need for constant None checks or asserts. _state_info: StateInfo = None # type: ignore[assignment] + _is_custom_component: bool = False __capabilities_updated_at: deque[float] __capabilities_updated_at_reported: bool = False @@ -967,8 +968,8 @@ class Entity( self._async_write_ha_state() @callback - def async_write_ha_state(self) -> None: - """Write the state to the state machine.""" + def _async_verify_state_writable(self) -> None: + """Verify the entity is in a writable state.""" if self.hass is None: raise RuntimeError(f"Attribute hass is None for {self}") if self.hass.config.debug: @@ -995,6 +996,18 @@ class Entity( f"No entity id specified for entity {self.name}" ) + @callback + def _async_write_ha_state_from_call_soon_threadsafe(self) -> None: + """Write the state to the state machine from the event loop thread.""" + self._async_verify_state_writable() + self._async_write_ha_state() + + @callback + def async_write_ha_state(self) -> None: + """Write the state to the state machine.""" + self._async_verify_state_writable() + if self._is_custom_component or self.hass.config.debug: + self.hass.verify_event_loop_thread("async_write_ha_state") self._async_write_ha_state() def _stringify_state(self, available: bool) -> str: @@ -1221,7 +1234,9 @@ class Entity( f"Entity {self.entity_id} schedule update ha state", ) else: - self.hass.loop.call_soon_threadsafe(self.async_write_ha_state) + self.hass.loop.call_soon_threadsafe( + self._async_write_ha_state_from_call_soon_threadsafe + ) @callback def async_schedule_update_ha_state(self, force_refresh: bool = False) -> None: @@ -1426,10 +1441,12 @@ class Entity( Not to be extended by integrations. """ + is_custom_component = "custom_components" in type(self).__module__ entity_info: EntityInfo = { "domain": self.platform.platform_name, - "custom_component": "custom_components" in type(self).__module__, + "custom_component": is_custom_component, } + self._is_custom_component = is_custom_component if self.platform.config_entry: entity_info["config_entry"] = self.platform.config_entry.entry_id diff --git a/tests/helpers/test_entity.py b/tests/helpers/test_entity.py index 349c065f9b53..a80674e0f764 100644 --- a/tests/helpers/test_entity.py +++ b/tests/helpers/test_entity.py @@ -2615,3 +2615,29 @@ async def test_async_write_ha_state_thread_safety(hass: HomeAssistant) -> None: ): await hass.async_add_executor_job(ent2.async_write_ha_state) assert not hass.states.get(ent2.entity_id) + + +async def test_async_write_ha_state_thread_safety_custom_component( + hass: HomeAssistant, +) -> None: + """Test async_write_ha_state thread safe for custom components.""" + + ent = entity.Entity() + ent._is_custom_component = True + ent.entity_id = "test.any" + ent.hass = hass + ent.platform = MockEntityPlatform(hass, domain="test") + ent.async_write_ha_state() + assert hass.states.get(ent.entity_id) + + ent2 = entity.Entity() + ent2._is_custom_component = True + ent2.entity_id = "test.any2" + ent2.hass = hass + ent2.platform = MockEntityPlatform(hass, domain="test") + with pytest.raises( + RuntimeError, + match="Detected code that calls async_write_ha_state from a thread.", + ): + await hass.async_add_executor_job(ent2.async_write_ha_state) + assert not hass.states.get(ent2.entity_id)