Prevent polling from recreating an entity after removal (#67750)

This commit is contained in:
J. Nick Koston 2022-03-08 05:42:16 +01:00 committed by GitHub
parent aed2c1cce8
commit a75bbc79a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 64 additions and 6 deletions

View File

@ -6,6 +6,7 @@ import asyncio
from collections.abc import Awaitable, Iterable, Mapping, MutableMapping
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum, auto
import functools as ft
import logging
import math
@ -207,6 +208,19 @@ class EntityCategory(StrEnum):
SYSTEM = "system"
class EntityPlatformState(Enum):
"""The platform state of an entity."""
# Not Added: Not yet added to a platform, polling updates are written to the state machine
NOT_ADDED = auto()
# Added: Added to a platform, polling updates are written to the state machine
ADDED = auto()
# Removed: Removed from a platform, polling updates are not written to the state machine
REMOVED = auto()
def convert_to_entity_category(
value: EntityCategory | str | None, raise_report: bool = True
) -> EntityCategory | None:
@ -294,7 +308,7 @@ class Entity(ABC):
_context_set: datetime | None = None
# If entity is added to an entity platform
_added = False
_platform_state = EntityPlatformState.NOT_ADDED
# Entity Properties
_attr_assumed_state: bool = False
@ -553,6 +567,10 @@ class Entity(ABC):
@callback
def _async_write_ha_state(self) -> None:
"""Write the state to the state machine."""
if self._platform_state == EntityPlatformState.REMOVED:
# Polling returned after the entity has already been removed
return
if self.registry_entry and self.registry_entry.disabled_by:
if not self._disabled_reported:
self._disabled_reported = True
@ -758,7 +776,7 @@ class Entity(ABC):
parallel_updates: asyncio.Semaphore | None,
) -> None:
"""Start adding an entity to a platform."""
if self._added:
if self._platform_state == EntityPlatformState.ADDED:
raise HomeAssistantError(
f"Entity {self.entity_id} cannot be added a second time to an entity platform"
)
@ -766,7 +784,7 @@ class Entity(ABC):
self.hass = hass
self.platform = platform
self.parallel_updates = parallel_updates
self._added = True
self._platform_state = EntityPlatformState.ADDED
@callback
def add_to_platform_abort(self) -> None:
@ -774,7 +792,7 @@ class Entity(ABC):
self.hass = None # type: ignore[assignment]
self.platform = None
self.parallel_updates = None
self._added = False
self._platform_state = EntityPlatformState.NOT_ADDED
async def add_to_platform_finish(self) -> None:
"""Finish adding an entity to a platform."""
@ -792,12 +810,12 @@ class Entity(ABC):
If the entity doesn't have a non disabled entry in the entity registry,
or if force_remove=True, its state will be removed.
"""
if self.platform and not self._added:
if self.platform and self._platform_state != EntityPlatformState.ADDED:
raise HomeAssistantError(
f"Entity {self.entity_id} async_remove called twice"
)
self._added = False
self._platform_state = EntityPlatformState.REMOVED
if self._on_remove is not None:
while self._on_remove:

View File

@ -545,6 +545,22 @@ async def test_async_remove_runs_callbacks(hass):
assert len(result) == 1
async def test_async_remove_ignores_in_flight_polling(hass):
"""Test in flight polling is ignored after removing."""
result = []
ent = entity.Entity()
ent.hass = hass
ent.entity_id = "test.test"
ent.async_on_remove(lambda: result.append(1))
ent.async_write_ha_state()
assert hass.states.get("test.test").state == STATE_UNKNOWN
await ent.async_remove()
assert len(result) == 1
assert hass.states.get("test.test") is None
ent.async_write_ha_state()
async def test_set_context(hass):
"""Test setting context."""
context = Context()

View File

@ -390,6 +390,30 @@ async def test_async_remove_with_platform(hass):
assert len(hass.states.async_entity_ids()) == 0
async def test_async_remove_with_platform_update_finishes(hass):
"""Remove an entity when an update finishes after its been removed."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
entity1 = MockEntity(name="test_1")
async def _delayed_update(*args, **kwargs):
await asyncio.sleep(0.01)
entity1.async_update = _delayed_update
# Add, remove, add, remove and make sure no updates
# cause the entity to reappear after removal
for i in range(2):
await component.async_add_entities([entity1])
assert len(hass.states.async_entity_ids()) == 1
entity1.async_write_ha_state()
assert hass.states.get(entity1.entity_id) is not None
task = asyncio.create_task(entity1.async_update_ha_state(True))
await entity1.async_remove()
assert len(hass.states.async_entity_ids()) == 0
await task
assert len(hass.states.async_entity_ids()) == 0
async def test_not_adding_duplicate_entities_with_unique_id(hass, caplog):
"""Test for not adding duplicate entities."""
caplog.set_level(logging.ERROR)