Refactor async_get_hass to rely on threading.local instead of a ContextVar (#96005)

* Test for async_get_hass

* Add Fix
This commit is contained in:
Jan Bouwhuis 2023-07-07 20:52:38 +02:00 committed by GitHub
parent 372687fe81
commit 18ee9f4725
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 205 additions and 25 deletions

View File

@ -16,7 +16,6 @@ from collections.abc import (
)
import concurrent.futures
from contextlib import suppress
from contextvars import ContextVar
import datetime
import enum
import functools
@ -155,8 +154,6 @@ MAX_EXPECTED_ENTITY_IDS = 16384
_LOGGER = logging.getLogger(__name__)
_cv_hass: ContextVar[HomeAssistant] = ContextVar("hass")
@functools.lru_cache(MAX_EXPECTED_ENTITY_IDS)
def split_entity_id(entity_id: str) -> tuple[str, str]:
@ -199,16 +196,27 @@ def is_callback(func: Callable[..., Any]) -> bool:
return getattr(func, "_hass_callback", False) is True
class _Hass(threading.local):
"""Container which makes a HomeAssistant instance available to the event loop."""
hass: HomeAssistant | None = None
_hass = _Hass()
@callback
def async_get_hass() -> HomeAssistant:
"""Return the HomeAssistant instance.
Raises LookupError if no HomeAssistant instance is available.
Raises HomeAssistantError when called from the wrong thread.
This should be used where it's very cumbersome or downright impossible to pass
hass to the code which needs it.
"""
return _cv_hass.get()
if not _hass.hass:
raise HomeAssistantError("async_get_hass called from the wrong thread")
return _hass.hass
@enum.unique
@ -292,9 +300,9 @@ class HomeAssistant:
config_entries: ConfigEntries = None # type: ignore[assignment]
def __new__(cls) -> HomeAssistant:
"""Set the _cv_hass context variable."""
"""Set the _hass thread local data."""
hass = super().__new__(cls)
_cv_hass.set(hass)
_hass.hass = hass
return hass
def __init__(self) -> None:

View File

@ -93,7 +93,7 @@ from homeassistant.core import (
split_entity_id,
valid_entity_id,
)
from homeassistant.exceptions import TemplateError
from homeassistant.exceptions import HomeAssistantError, TemplateError
from homeassistant.generated import currencies
from homeassistant.generated.countries import COUNTRIES
from homeassistant.generated.languages import LANGUAGES
@ -609,7 +609,7 @@ def template(value: Any | None) -> template_helper.Template:
raise vol.Invalid("template value should be a string")
hass: HomeAssistant | None = None
with contextlib.suppress(LookupError):
with contextlib.suppress(HomeAssistantError):
hass = async_get_hass()
template_value = template_helper.Template(str(value), hass)
@ -631,7 +631,7 @@ def dynamic_template(value: Any | None) -> template_helper.Template:
raise vol.Invalid("template value does not contain a dynamic template")
hass: HomeAssistant | None = None
with contextlib.suppress(LookupError):
with contextlib.suppress(HomeAssistantError):
hass = async_get_hass()
template_value = template_helper.Template(str(value), hass)
@ -1098,7 +1098,7 @@ def _no_yaml_config_schema(
# pylint: disable-next=import-outside-toplevel
from .issue_registry import IssueSeverity, async_create_issue
with contextlib.suppress(LookupError):
with contextlib.suppress(HomeAssistantError):
hass = async_get_hass()
async_create_issue(
hass,

View File

@ -490,17 +490,7 @@ def hass_fixture_setup() -> list[bool]:
@pytest.fixture
def hass(_hass: HomeAssistant) -> HomeAssistant:
"""Fixture to provide a test instance of Home Assistant."""
# This wraps the async _hass fixture inside a sync fixture, to ensure
# the `hass` context variable is set in the execution context in which
# the test itself is executed
ha._cv_hass.set(_hass)
return _hass
@pytest.fixture
async def _hass(
async def hass(
hass_fixture_setup: list[bool],
event_loop: asyncio.AbstractEventLoop,
load_registries: bool,

View File

@ -12,6 +12,7 @@ import voluptuous as vol
import homeassistant
from homeassistant.core import DOMAIN as HOMEASSISTANT_DOMAIN, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import (
config_validation as cv,
issue_registry as ir,
@ -383,7 +384,7 @@ def test_service() -> None:
schema("homeassistant.turn_on")
def test_service_schema() -> None:
def test_service_schema(hass: HomeAssistant) -> None:
"""Test service_schema validation."""
options = (
{},
@ -1550,10 +1551,10 @@ def test_config_entry_only_schema_cant_find_module() -> None:
def test_config_entry_only_schema_no_hass(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test if the the hass context var is not set in our context."""
"""Test if the the hass context is not set in our context."""
with patch(
"homeassistant.helpers.config_validation.async_get_hass",
side_effect=LookupError,
side_effect=HomeAssistantError,
):
cv.config_entry_only_config_schema("test_domain")(
{"test_domain": {"foo": "bar"}}

View File

@ -9,10 +9,12 @@ import gc
import logging
import os
from tempfile import TemporaryDirectory
import threading
import time
from typing import Any
from unittest.mock import MagicMock, Mock, PropertyMock, patch
import async_timeout
import pytest
import voluptuous as vol
@ -40,6 +42,7 @@ from homeassistant.core import (
ServiceResponse,
State,
SupportsResponse,
callback,
)
from homeassistant.exceptions import (
HomeAssistantError,
@ -202,6 +205,184 @@ def test_async_run_hass_job_delegates_non_async() -> None:
assert len(hass.async_add_hass_job.mock_calls) == 1
async def test_async_get_hass_can_be_called(hass: HomeAssistant) -> None:
"""Test calling async_get_hass via different paths.
The test asserts async_get_hass can be called from:
- Coroutines and callbacks
- Callbacks scheduled from callbacks, coroutines and threads
- Coroutines scheduled from callbacks, coroutines and threads
The test also asserts async_get_hass can not be called from threads
other than the event loop.
"""
task_finished = asyncio.Event()
def can_call_async_get_hass() -> bool:
"""Test if it's possible to call async_get_hass."""
try:
if ha.async_get_hass() is hass:
return True
raise Exception
except HomeAssistantError:
return False
raise Exception
# Test scheduling a coroutine which calls async_get_hass via hass.async_create_task
async def _async_create_task() -> None:
task_finished.set()
assert can_call_async_get_hass()
hass.async_create_task(_async_create_task(), "create_task")
async with async_timeout.timeout(1):
await task_finished.wait()
task_finished.clear()
# Test scheduling a callback which calls async_get_hass via hass.async_add_job
@callback
def _add_job() -> None:
assert can_call_async_get_hass()
task_finished.set()
hass.async_add_job(_add_job)
async with async_timeout.timeout(1):
await task_finished.wait()
task_finished.clear()
# Test scheduling a callback which calls async_get_hass from a callback
@callback
def _schedule_callback_from_callback() -> None:
@callback
def _callback():
assert can_call_async_get_hass()
task_finished.set()
# Test the scheduled callback itself can call async_get_hass
assert can_call_async_get_hass()
hass.async_add_job(_callback)
_schedule_callback_from_callback()
async with async_timeout.timeout(1):
await task_finished.wait()
task_finished.clear()
# Test scheduling a coroutine which calls async_get_hass from a callback
@callback
def _schedule_coroutine_from_callback() -> None:
async def _coroutine():
assert can_call_async_get_hass()
task_finished.set()
# Test the scheduled callback itself can call async_get_hass
assert can_call_async_get_hass()
hass.async_add_job(_coroutine())
_schedule_coroutine_from_callback()
async with async_timeout.timeout(1):
await task_finished.wait()
task_finished.clear()
# Test scheduling a callback which calls async_get_hass from a coroutine
async def _schedule_callback_from_coroutine() -> None:
@callback
def _callback():
assert can_call_async_get_hass()
task_finished.set()
# Test the coroutine itself can call async_get_hass
assert can_call_async_get_hass()
hass.async_add_job(_callback)
await _schedule_callback_from_coroutine()
async with async_timeout.timeout(1):
await task_finished.wait()
task_finished.clear()
# Test scheduling a coroutine which calls async_get_hass from a coroutine
async def _schedule_callback_from_coroutine() -> None:
async def _coroutine():
assert can_call_async_get_hass()
task_finished.set()
# Test the coroutine itself can call async_get_hass
assert can_call_async_get_hass()
await hass.async_create_task(_coroutine())
await _schedule_callback_from_coroutine()
async with async_timeout.timeout(1):
await task_finished.wait()
task_finished.clear()
# Test scheduling a callback which calls async_get_hass from an executor
def _async_add_executor_job_add_job() -> None:
@callback
def _async_add_job():
assert can_call_async_get_hass()
task_finished.set()
# Test the executor itself can not call async_get_hass
assert not can_call_async_get_hass()
hass.add_job(_async_add_job)
await hass.async_add_executor_job(_async_add_executor_job_add_job)
async with async_timeout.timeout(1):
await task_finished.wait()
task_finished.clear()
# Test scheduling a coroutine which calls async_get_hass from an executor
def _async_add_executor_job_create_task() -> None:
async def _async_create_task() -> None:
assert can_call_async_get_hass()
task_finished.set()
# Test the executor itself can not call async_get_hass
assert not can_call_async_get_hass()
hass.create_task(_async_create_task())
await hass.async_add_executor_job(_async_add_executor_job_create_task)
async with async_timeout.timeout(1):
await task_finished.wait()
task_finished.clear()
# Test scheduling a callback which calls async_get_hass from a worker thread
class MyJobAddJob(threading.Thread):
@callback
def _my_threaded_job_add_job(self) -> None:
assert can_call_async_get_hass()
task_finished.set()
def run(self) -> None:
# Test the worker thread itself can not call async_get_hass
assert not can_call_async_get_hass()
hass.add_job(self._my_threaded_job_add_job)
my_job_add_job = MyJobAddJob()
my_job_add_job.start()
async with async_timeout.timeout(1):
await task_finished.wait()
task_finished.clear()
my_job_add_job.join()
# Test scheduling a coroutine which calls async_get_hass from a worker thread
class MyJobCreateTask(threading.Thread):
async def _my_threaded_job_create_task(self) -> None:
assert can_call_async_get_hass()
task_finished.set()
def run(self) -> None:
# Test the worker thread itself can not call async_get_hass
assert not can_call_async_get_hass()
hass.create_task(self._my_threaded_job_create_task())
my_job_create_task = MyJobCreateTask()
my_job_create_task.start()
async with async_timeout.timeout(1):
await task_finished.wait()
task_finished.clear()
my_job_create_task.join()
async def test_stage_shutdown(hass: HomeAssistant) -> None:
"""Simulate a shutdown, test calling stuff."""
test_stop = async_capture_events(hass, EVENT_HOMEASSISTANT_STOP)