Always do thread safety checks when writing state for custom components (#116044)

This commit is contained in:
J. Nick Koston 2024-04-24 10:05:52 +02:00 committed by GitHub
parent 6f2a2ba46e
commit 79b4889812
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 47 additions and 4 deletions

View File

@ -521,6 +521,7 @@ class Entity(
# While not purely typed, it makes typehinting more useful for us
# and removes the need for constant None checks or asserts.
_state_info: StateInfo = None # type: ignore[assignment]
_is_custom_component: bool = False
__capabilities_updated_at: deque[float]
__capabilities_updated_at_reported: bool = False
@ -967,8 +968,8 @@ class Entity(
self._async_write_ha_state()
@callback
def async_write_ha_state(self) -> None:
"""Write the state to the state machine."""
def _async_verify_state_writable(self) -> None:
"""Verify the entity is in a writable state."""
if self.hass is None:
raise RuntimeError(f"Attribute hass is None for {self}")
if self.hass.config.debug:
@ -995,6 +996,18 @@ class Entity(
f"No entity id specified for entity {self.name}"
)
@callback
def _async_write_ha_state_from_call_soon_threadsafe(self) -> None:
"""Write the state to the state machine from the event loop thread."""
self._async_verify_state_writable()
self._async_write_ha_state()
@callback
def async_write_ha_state(self) -> None:
"""Write the state to the state machine."""
self._async_verify_state_writable()
if self._is_custom_component or self.hass.config.debug:
self.hass.verify_event_loop_thread("async_write_ha_state")
self._async_write_ha_state()
def _stringify_state(self, available: bool) -> str:
@ -1221,7 +1234,9 @@ class Entity(
f"Entity {self.entity_id} schedule update ha state",
)
else:
self.hass.loop.call_soon_threadsafe(self.async_write_ha_state)
self.hass.loop.call_soon_threadsafe(
self._async_write_ha_state_from_call_soon_threadsafe
)
@callback
def async_schedule_update_ha_state(self, force_refresh: bool = False) -> None:
@ -1426,10 +1441,12 @@ class Entity(
Not to be extended by integrations.
"""
is_custom_component = "custom_components" in type(self).__module__
entity_info: EntityInfo = {
"domain": self.platform.platform_name,
"custom_component": "custom_components" in type(self).__module__,
"custom_component": is_custom_component,
}
self._is_custom_component = is_custom_component
if self.platform.config_entry:
entity_info["config_entry"] = self.platform.config_entry.entry_id

View File

@ -2615,3 +2615,29 @@ async def test_async_write_ha_state_thread_safety(hass: HomeAssistant) -> None:
):
await hass.async_add_executor_job(ent2.async_write_ha_state)
assert not hass.states.get(ent2.entity_id)
async def test_async_write_ha_state_thread_safety_custom_component(
hass: HomeAssistant,
) -> None:
"""Test async_write_ha_state thread safe for custom components."""
ent = entity.Entity()
ent._is_custom_component = True
ent.entity_id = "test.any"
ent.hass = hass
ent.platform = MockEntityPlatform(hass, domain="test")
ent.async_write_ha_state()
assert hass.states.get(ent.entity_id)
ent2 = entity.Entity()
ent2._is_custom_component = True
ent2.entity_id = "test.any2"
ent2.hass = hass
ent2.platform = MockEntityPlatform(hass, domain="test")
with pytest.raises(
RuntimeError,
match="Detected code that calls async_write_ha_state from a thread.",
):
await hass.async_add_executor_job(ent2.async_write_ha_state)
assert not hass.states.get(ent2.entity_id)