diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index 8e4f6bc8b588..a554a093c5c2 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -6,6 +6,7 @@ import asyncio from collections.abc import Awaitable, Iterable, Mapping, MutableMapping from dataclasses import dataclass from datetime import datetime, timedelta +from enum import Enum, auto import functools as ft import logging import math @@ -207,6 +208,19 @@ class EntityCategory(StrEnum): SYSTEM = "system" +class EntityPlatformState(Enum): + """The platform state of an entity.""" + + # Not Added: Not yet added to a platform, polling updates are written to the state machine + NOT_ADDED = auto() + + # Added: Added to a platform, polling updates are written to the state machine + ADDED = auto() + + # Removed: Removed from a platform, polling updates are not written to the state machine + REMOVED = auto() + + def convert_to_entity_category( value: EntityCategory | str | None, raise_report: bool = True ) -> EntityCategory | None: @@ -294,7 +308,7 @@ class Entity(ABC): _context_set: datetime | None = None # If entity is added to an entity platform - _added = False + _platform_state = EntityPlatformState.NOT_ADDED # Entity Properties _attr_assumed_state: bool = False @@ -553,6 +567,10 @@ class Entity(ABC): @callback def _async_write_ha_state(self) -> None: """Write the state to the state machine.""" + if self._platform_state == EntityPlatformState.REMOVED: + # Polling returned after the entity has already been removed + return + if self.registry_entry and self.registry_entry.disabled_by: if not self._disabled_reported: self._disabled_reported = True @@ -758,7 +776,7 @@ class Entity(ABC): parallel_updates: asyncio.Semaphore | None, ) -> None: """Start adding an entity to a platform.""" - if self._added: + if self._platform_state == EntityPlatformState.ADDED: raise HomeAssistantError( f"Entity {self.entity_id} cannot be added a second time to an entity platform" ) @@ -766,7 +784,7 @@ class Entity(ABC): self.hass = hass self.platform = platform self.parallel_updates = parallel_updates - self._added = True + self._platform_state = EntityPlatformState.ADDED @callback def add_to_platform_abort(self) -> None: @@ -774,7 +792,7 @@ class Entity(ABC): self.hass = None # type: ignore[assignment] self.platform = None self.parallel_updates = None - self._added = False + self._platform_state = EntityPlatformState.NOT_ADDED async def add_to_platform_finish(self) -> None: """Finish adding an entity to a platform.""" @@ -792,12 +810,12 @@ class Entity(ABC): If the entity doesn't have a non disabled entry in the entity registry, or if force_remove=True, its state will be removed. """ - if self.platform and not self._added: + if self.platform and self._platform_state != EntityPlatformState.ADDED: raise HomeAssistantError( f"Entity {self.entity_id} async_remove called twice" ) - self._added = False + self._platform_state = EntityPlatformState.REMOVED if self._on_remove is not None: while self._on_remove: diff --git a/tests/helpers/test_entity.py b/tests/helpers/test_entity.py index 6b7de074a24d..afc0887371ee 100644 --- a/tests/helpers/test_entity.py +++ b/tests/helpers/test_entity.py @@ -545,6 +545,22 @@ async def test_async_remove_runs_callbacks(hass): assert len(result) == 1 +async def test_async_remove_ignores_in_flight_polling(hass): + """Test in flight polling is ignored after removing.""" + result = [] + + ent = entity.Entity() + ent.hass = hass + ent.entity_id = "test.test" + ent.async_on_remove(lambda: result.append(1)) + ent.async_write_ha_state() + assert hass.states.get("test.test").state == STATE_UNKNOWN + await ent.async_remove() + assert len(result) == 1 + assert hass.states.get("test.test") is None + ent.async_write_ha_state() + + async def test_set_context(hass): """Test setting context.""" context = Context() diff --git a/tests/helpers/test_entity_platform.py b/tests/helpers/test_entity_platform.py index 9aa0a849e5a8..c98fdff78586 100644 --- a/tests/helpers/test_entity_platform.py +++ b/tests/helpers/test_entity_platform.py @@ -390,6 +390,30 @@ async def test_async_remove_with_platform(hass): assert len(hass.states.async_entity_ids()) == 0 +async def test_async_remove_with_platform_update_finishes(hass): + """Remove an entity when an update finishes after its been removed.""" + component = EntityComponent(_LOGGER, DOMAIN, hass) + entity1 = MockEntity(name="test_1") + + async def _delayed_update(*args, **kwargs): + await asyncio.sleep(0.01) + + entity1.async_update = _delayed_update + + # Add, remove, add, remove and make sure no updates + # cause the entity to reappear after removal + for i in range(2): + await component.async_add_entities([entity1]) + assert len(hass.states.async_entity_ids()) == 1 + entity1.async_write_ha_state() + assert hass.states.get(entity1.entity_id) is not None + task = asyncio.create_task(entity1.async_update_ha_state(True)) + await entity1.async_remove() + assert len(hass.states.async_entity_ids()) == 0 + await task + assert len(hass.states.async_entity_ids()) == 0 + + async def test_not_adding_duplicate_entities_with_unique_id(hass, caplog): """Test for not adding duplicate entities.""" caplog.set_level(logging.ERROR)