Refactor async_get_hass to rely on threading.local instead of a ContextVar (#96005)

* Test for async_get_hass

* Add Fix
This commit is contained in:
Jan Bouwhuis 2023-07-07 20:52:38 +02:00 committed by GitHub
parent 372687fe81
commit 18ee9f4725
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 205 additions and 25 deletions

View File

@ -16,7 +16,6 @@ from collections.abc import (
) )
import concurrent.futures import concurrent.futures
from contextlib import suppress from contextlib import suppress
from contextvars import ContextVar
import datetime import datetime
import enum import enum
import functools import functools
@ -155,8 +154,6 @@ MAX_EXPECTED_ENTITY_IDS = 16384
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_cv_hass: ContextVar[HomeAssistant] = ContextVar("hass")
@functools.lru_cache(MAX_EXPECTED_ENTITY_IDS) @functools.lru_cache(MAX_EXPECTED_ENTITY_IDS)
def split_entity_id(entity_id: str) -> tuple[str, str]: def split_entity_id(entity_id: str) -> tuple[str, str]:
@ -199,16 +196,27 @@ def is_callback(func: Callable[..., Any]) -> bool:
return getattr(func, "_hass_callback", False) is True return getattr(func, "_hass_callback", False) is True
class _Hass(threading.local):
"""Container which makes a HomeAssistant instance available to the event loop."""
hass: HomeAssistant | None = None
_hass = _Hass()
@callback @callback
def async_get_hass() -> HomeAssistant: def async_get_hass() -> HomeAssistant:
"""Return the HomeAssistant instance. """Return the HomeAssistant instance.
Raises LookupError if no HomeAssistant instance is available. Raises HomeAssistantError when called from the wrong thread.
This should be used where it's very cumbersome or downright impossible to pass This should be used where it's very cumbersome or downright impossible to pass
hass to the code which needs it. hass to the code which needs it.
""" """
return _cv_hass.get() if not _hass.hass:
raise HomeAssistantError("async_get_hass called from the wrong thread")
return _hass.hass
@enum.unique @enum.unique
@ -292,9 +300,9 @@ class HomeAssistant:
config_entries: ConfigEntries = None # type: ignore[assignment] config_entries: ConfigEntries = None # type: ignore[assignment]
def __new__(cls) -> HomeAssistant: def __new__(cls) -> HomeAssistant:
"""Set the _cv_hass context variable.""" """Set the _hass thread local data."""
hass = super().__new__(cls) hass = super().__new__(cls)
_cv_hass.set(hass) _hass.hass = hass
return hass return hass
def __init__(self) -> None: def __init__(self) -> None:

View File

@ -93,7 +93,7 @@ from homeassistant.core import (
split_entity_id, split_entity_id,
valid_entity_id, valid_entity_id,
) )
from homeassistant.exceptions import TemplateError from homeassistant.exceptions import HomeAssistantError, TemplateError
from homeassistant.generated import currencies from homeassistant.generated import currencies
from homeassistant.generated.countries import COUNTRIES from homeassistant.generated.countries import COUNTRIES
from homeassistant.generated.languages import LANGUAGES from homeassistant.generated.languages import LANGUAGES
@ -609,7 +609,7 @@ def template(value: Any | None) -> template_helper.Template:
raise vol.Invalid("template value should be a string") raise vol.Invalid("template value should be a string")
hass: HomeAssistant | None = None hass: HomeAssistant | None = None
with contextlib.suppress(LookupError): with contextlib.suppress(HomeAssistantError):
hass = async_get_hass() hass = async_get_hass()
template_value = template_helper.Template(str(value), hass) template_value = template_helper.Template(str(value), hass)
@ -631,7 +631,7 @@ def dynamic_template(value: Any | None) -> template_helper.Template:
raise vol.Invalid("template value does not contain a dynamic template") raise vol.Invalid("template value does not contain a dynamic template")
hass: HomeAssistant | None = None hass: HomeAssistant | None = None
with contextlib.suppress(LookupError): with contextlib.suppress(HomeAssistantError):
hass = async_get_hass() hass = async_get_hass()
template_value = template_helper.Template(str(value), hass) template_value = template_helper.Template(str(value), hass)
@ -1098,7 +1098,7 @@ def _no_yaml_config_schema(
# pylint: disable-next=import-outside-toplevel # pylint: disable-next=import-outside-toplevel
from .issue_registry import IssueSeverity, async_create_issue from .issue_registry import IssueSeverity, async_create_issue
with contextlib.suppress(LookupError): with contextlib.suppress(HomeAssistantError):
hass = async_get_hass() hass = async_get_hass()
async_create_issue( async_create_issue(
hass, hass,

View File

@ -490,17 +490,7 @@ def hass_fixture_setup() -> list[bool]:
@pytest.fixture @pytest.fixture
def hass(_hass: HomeAssistant) -> HomeAssistant: async def hass(
"""Fixture to provide a test instance of Home Assistant."""
# This wraps the async _hass fixture inside a sync fixture, to ensure
# the `hass` context variable is set in the execution context in which
# the test itself is executed
ha._cv_hass.set(_hass)
return _hass
@pytest.fixture
async def _hass(
hass_fixture_setup: list[bool], hass_fixture_setup: list[bool],
event_loop: asyncio.AbstractEventLoop, event_loop: asyncio.AbstractEventLoop,
load_registries: bool, load_registries: bool,

View File

@ -12,6 +12,7 @@ import voluptuous as vol
import homeassistant import homeassistant
from homeassistant.core import DOMAIN as HOMEASSISTANT_DOMAIN, HomeAssistant from homeassistant.core import DOMAIN as HOMEASSISTANT_DOMAIN, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import ( from homeassistant.helpers import (
config_validation as cv, config_validation as cv,
issue_registry as ir, issue_registry as ir,
@ -383,7 +384,7 @@ def test_service() -> None:
schema("homeassistant.turn_on") schema("homeassistant.turn_on")
def test_service_schema() -> None: def test_service_schema(hass: HomeAssistant) -> None:
"""Test service_schema validation.""" """Test service_schema validation."""
options = ( options = (
{}, {},
@ -1550,10 +1551,10 @@ def test_config_entry_only_schema_cant_find_module() -> None:
def test_config_entry_only_schema_no_hass( def test_config_entry_only_schema_no_hass(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None: ) -> None:
"""Test if the the hass context var is not set in our context.""" """Test if the the hass context is not set in our context."""
with patch( with patch(
"homeassistant.helpers.config_validation.async_get_hass", "homeassistant.helpers.config_validation.async_get_hass",
side_effect=LookupError, side_effect=HomeAssistantError,
): ):
cv.config_entry_only_config_schema("test_domain")( cv.config_entry_only_config_schema("test_domain")(
{"test_domain": {"foo": "bar"}} {"test_domain": {"foo": "bar"}}

View File

@ -9,10 +9,12 @@ import gc
import logging import logging
import os import os
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
import threading
import time import time
from typing import Any from typing import Any
from unittest.mock import MagicMock, Mock, PropertyMock, patch from unittest.mock import MagicMock, Mock, PropertyMock, patch
import async_timeout
import pytest import pytest
import voluptuous as vol import voluptuous as vol
@ -40,6 +42,7 @@ from homeassistant.core import (
ServiceResponse, ServiceResponse,
State, State,
SupportsResponse, SupportsResponse,
callback,
) )
from homeassistant.exceptions import ( from homeassistant.exceptions import (
HomeAssistantError, HomeAssistantError,
@ -202,6 +205,184 @@ def test_async_run_hass_job_delegates_non_async() -> None:
assert len(hass.async_add_hass_job.mock_calls) == 1 assert len(hass.async_add_hass_job.mock_calls) == 1
async def test_async_get_hass_can_be_called(hass: HomeAssistant) -> None:
"""Test calling async_get_hass via different paths.
The test asserts async_get_hass can be called from:
- Coroutines and callbacks
- Callbacks scheduled from callbacks, coroutines and threads
- Coroutines scheduled from callbacks, coroutines and threads
The test also asserts async_get_hass can not be called from threads
other than the event loop.
"""
task_finished = asyncio.Event()
def can_call_async_get_hass() -> bool:
"""Test if it's possible to call async_get_hass."""
try:
if ha.async_get_hass() is hass:
return True
raise Exception
except HomeAssistantError:
return False
raise Exception
# Test scheduling a coroutine which calls async_get_hass via hass.async_create_task
async def _async_create_task() -> None:
task_finished.set()
assert can_call_async_get_hass()
hass.async_create_task(_async_create_task(), "create_task")
async with async_timeout.timeout(1):
await task_finished.wait()
task_finished.clear()
# Test scheduling a callback which calls async_get_hass via hass.async_add_job
@callback
def _add_job() -> None:
assert can_call_async_get_hass()
task_finished.set()
hass.async_add_job(_add_job)
async with async_timeout.timeout(1):
await task_finished.wait()
task_finished.clear()
# Test scheduling a callback which calls async_get_hass from a callback
@callback
def _schedule_callback_from_callback() -> None:
@callback
def _callback():
assert can_call_async_get_hass()
task_finished.set()
# Test the scheduled callback itself can call async_get_hass
assert can_call_async_get_hass()
hass.async_add_job(_callback)
_schedule_callback_from_callback()
async with async_timeout.timeout(1):
await task_finished.wait()
task_finished.clear()
# Test scheduling a coroutine which calls async_get_hass from a callback
@callback
def _schedule_coroutine_from_callback() -> None:
async def _coroutine():
assert can_call_async_get_hass()
task_finished.set()
# Test the scheduled callback itself can call async_get_hass
assert can_call_async_get_hass()
hass.async_add_job(_coroutine())
_schedule_coroutine_from_callback()
async with async_timeout.timeout(1):
await task_finished.wait()
task_finished.clear()
# Test scheduling a callback which calls async_get_hass from a coroutine
async def _schedule_callback_from_coroutine() -> None:
@callback
def _callback():
assert can_call_async_get_hass()
task_finished.set()
# Test the coroutine itself can call async_get_hass
assert can_call_async_get_hass()
hass.async_add_job(_callback)
await _schedule_callback_from_coroutine()
async with async_timeout.timeout(1):
await task_finished.wait()
task_finished.clear()
# Test scheduling a coroutine which calls async_get_hass from a coroutine
async def _schedule_callback_from_coroutine() -> None:
async def _coroutine():
assert can_call_async_get_hass()
task_finished.set()
# Test the coroutine itself can call async_get_hass
assert can_call_async_get_hass()
await hass.async_create_task(_coroutine())
await _schedule_callback_from_coroutine()
async with async_timeout.timeout(1):
await task_finished.wait()
task_finished.clear()
# Test scheduling a callback which calls async_get_hass from an executor
def _async_add_executor_job_add_job() -> None:
@callback
def _async_add_job():
assert can_call_async_get_hass()
task_finished.set()
# Test the executor itself can not call async_get_hass
assert not can_call_async_get_hass()
hass.add_job(_async_add_job)
await hass.async_add_executor_job(_async_add_executor_job_add_job)
async with async_timeout.timeout(1):
await task_finished.wait()
task_finished.clear()
# Test scheduling a coroutine which calls async_get_hass from an executor
def _async_add_executor_job_create_task() -> None:
async def _async_create_task() -> None:
assert can_call_async_get_hass()
task_finished.set()
# Test the executor itself can not call async_get_hass
assert not can_call_async_get_hass()
hass.create_task(_async_create_task())
await hass.async_add_executor_job(_async_add_executor_job_create_task)
async with async_timeout.timeout(1):
await task_finished.wait()
task_finished.clear()
# Test scheduling a callback which calls async_get_hass from a worker thread
class MyJobAddJob(threading.Thread):
@callback
def _my_threaded_job_add_job(self) -> None:
assert can_call_async_get_hass()
task_finished.set()
def run(self) -> None:
# Test the worker thread itself can not call async_get_hass
assert not can_call_async_get_hass()
hass.add_job(self._my_threaded_job_add_job)
my_job_add_job = MyJobAddJob()
my_job_add_job.start()
async with async_timeout.timeout(1):
await task_finished.wait()
task_finished.clear()
my_job_add_job.join()
# Test scheduling a coroutine which calls async_get_hass from a worker thread
class MyJobCreateTask(threading.Thread):
async def _my_threaded_job_create_task(self) -> None:
assert can_call_async_get_hass()
task_finished.set()
def run(self) -> None:
# Test the worker thread itself can not call async_get_hass
assert not can_call_async_get_hass()
hass.create_task(self._my_threaded_job_create_task())
my_job_create_task = MyJobCreateTask()
my_job_create_task.start()
async with async_timeout.timeout(1):
await task_finished.wait()
task_finished.clear()
my_job_create_task.join()
async def test_stage_shutdown(hass: HomeAssistant) -> None: async def test_stage_shutdown(hass: HomeAssistant) -> None:
"""Simulate a shutdown, test calling stuff.""" """Simulate a shutdown, test calling stuff."""
test_stop = async_capture_events(hass, EVENT_HOMEASSISTANT_STOP) test_stop = async_capture_events(hass, EVENT_HOMEASSISTANT_STOP)