From 049c06061ce92834b0c82b0e8b06ae7520322e54 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 26 May 2022 17:54:26 -1000 Subject: [PATCH] Fix memory leak when firing state_changed events (#72571) --- homeassistant/components/recorder/models.py | 2 +- homeassistant/core.py | 45 +++++++++++++++++---- tests/test_core.py | 45 +++++++++++++++++++++ 3 files changed, 84 insertions(+), 8 deletions(-) diff --git a/homeassistant/components/recorder/models.py b/homeassistant/components/recorder/models.py index dff8edde79f8..70c816c2af5f 100644 --- a/homeassistant/components/recorder/models.py +++ b/homeassistant/components/recorder/models.py @@ -746,7 +746,7 @@ class LazyState(State): def context(self) -> Context: # type: ignore[override] """State context.""" if self._context is None: - self._context = Context(id=None) # type: ignore[arg-type] + self._context = Context(id=None) return self._context @context.setter diff --git a/homeassistant/core.py b/homeassistant/core.py index 2753b8013471..d7cae4e411ea 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -37,7 +37,6 @@ from typing import ( ) from urllib.parse import urlparse -import attr import voluptuous as vol import yarl @@ -716,14 +715,26 @@ class HomeAssistant: self._stopped.set() -@attr.s(slots=True, frozen=False) class Context: """The context that triggered something.""" - user_id: str | None = attr.ib(default=None) - parent_id: str | None = attr.ib(default=None) - id: str = attr.ib(factory=ulid_util.ulid) - origin_event: Event | None = attr.ib(default=None, eq=False) + __slots__ = ("user_id", "parent_id", "id", "origin_event") + + def __init__( + self, + user_id: str | None = None, + parent_id: str | None = None, + id: str | None = None, # pylint: disable=redefined-builtin + ) -> None: + """Init the context.""" + self.id = id or ulid_util.ulid() + self.user_id = user_id + self.parent_id = parent_id + self.origin_event: Event | None = None + + def __eq__(self, other: Any) -> bool: + """Compare contexts.""" + return bool(self.__class__ == other.__class__ and self.id == other.id) def as_dict(self) -> dict[str, str | None]: """Return a dictionary representation of the context.""" @@ -1163,6 +1174,24 @@ class State: context, ) + def expire(self) -> None: + """Mark the state as old. + + We give up the original reference to the context to ensure + the context can be garbage collected by replacing it with + a new one with the same id to ensure the old state + can still be examined for comparison against the new state. + + Since we are always going to fire a EVENT_STATE_CHANGED event + after we remove a state from the state machine we need to make + sure we don't end up holding a reference to the original context + since it can never be garbage collected as each event would + reference the previous one. + """ + self.context = Context( + self.context.user_id, self.context.parent_id, self.context.id + ) + def __eq__(self, other: Any) -> bool: """Return the comparison of the state.""" return ( # type: ignore[no-any-return] @@ -1303,6 +1332,7 @@ class StateMachine: if old_state is None: return False + old_state.expire() self._bus.async_fire( EVENT_STATE_CHANGED, {"entity_id": entity_id, "old_state": old_state, "new_state": None}, @@ -1396,7 +1426,6 @@ class StateMachine: if context is None: context = Context(id=ulid_util.ulid(dt_util.utc_to_timestamp(now))) - state = State( entity_id, new_state, @@ -1406,6 +1435,8 @@ class StateMachine: context, old_state is None, ) + if old_state is not None: + old_state.expire() self._states[entity_id] = state self._bus.async_fire( EVENT_STATE_CHANGED, diff --git a/tests/test_core.py b/tests/test_core.py index ee1005a60b04..67513ea8b17b 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -6,9 +6,11 @@ import array import asyncio from datetime import datetime, timedelta import functools +import gc import logging import os from tempfile import TemporaryDirectory +from typing import Any from unittest.mock import MagicMock, Mock, PropertyMock, patch import pytest @@ -1829,3 +1831,46 @@ async def test_event_context(hass): cancel2() assert dummy_event2.context.origin_event == dummy_event + + +def _get_full_name(obj) -> str: + """Get the full name of an object in memory.""" + objtype = type(obj) + name = objtype.__name__ + if module := getattr(objtype, "__module__", None): + return f"{module}.{name}" + return name + + +def _get_by_type(full_name: str) -> list[Any]: + """Get all objects in memory with a specific type.""" + return [obj for obj in gc.get_objects() if _get_full_name(obj) == full_name] + + +# The logger will hold a strong reference to the event for the life of the tests +# so we must patch it out +@pytest.mark.skipif( + not os.environ.get("DEBUG_MEMORY"), + reason="Takes too long on the CI", +) +@patch.object(ha._LOGGER, "debug", lambda *args: None) +async def test_state_changed_events_to_not_leak_contexts(hass): + """Test state changed events do not leak contexts.""" + gc.collect() + # Other tests can log Contexts which keep them in memory + # so we need to look at how many exist at the start + init_count = len(_get_by_type("homeassistant.core.Context")) + + assert len(_get_by_type("homeassistant.core.Context")) == init_count + for i in range(20): + hass.states.async_set("light.switch", str(i)) + await hass.async_block_till_done() + gc.collect() + + assert len(_get_by_type("homeassistant.core.Context")) == init_count + 2 + + hass.states.async_remove("light.switch") + await hass.async_block_till_done() + gc.collect() + + assert len(_get_by_type("homeassistant.core.Context")) == init_count