Improve recorder type hints in tests (#87826)

* Improve recorder type hints in tests

* Add comment

* Adjust comment
This commit is contained in:
epenet 2023-02-10 11:11:39 +01:00 committed by GitHub
parent b5dfd83c46
commit fac746c974
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 59 additions and 30 deletions

View File

@ -94,17 +94,24 @@ _TEST_FIXTURES: dict[str, list[str] | str] = {
"aioclient_mock": "AiohttpClientMocker",
"aiohttp_client": "ClientSessionGenerator",
"area_registry": "AreaRegistry",
"async_setup_recorder_instance": "RecorderInstanceGenerator",
"caplog": "pytest.LogCaptureFixture",
"device_registry": "DeviceRegistry",
"enable_nightly_purge": "bool",
"enable_statistics": "bool",
"enable_statistics_table_validation": "bool",
"entity_registry": "EntityRegistry",
"hass_client": "ClientSessionGenerator",
"hass_client_no_auth": "ClientSessionGenerator",
"hass_recorder": "Callable[..., HomeAssistant]",
"hass_ws_client": "WebSocketGenerator",
"issue_registry": "IssueRegistry",
"mqtt_client_mock": "MqttMockPahoClient",
"mqtt_mock": "MqttMockHAClient",
"mqtt_mock_entry_no_yaml_config": "MqttMockHAClientGenerator",
"mqtt_mock_entry_with_yaml_config": "MqttMockHAClientGenerator",
"recorder_db_url": "str",
"recorder_mock": "Recorder",
}
_TEST_FUNCTION_MATCH = TypeHintMatch(
function_name="test_*",

View File

@ -13,7 +13,7 @@ import logging
import sqlite3
import ssl
import threading
from typing import Any
from typing import TYPE_CHECKING, Any, cast
from unittest.mock import AsyncMock, MagicMock, Mock, patch
from aiohttp import ClientWebSocketResponse, client
@ -65,9 +65,15 @@ from .typing import (
MqttMockHAClient,
MqttMockHAClientGenerator,
MqttMockPahoClient,
RecorderInstanceGenerator,
WebSocketGenerator,
)
if TYPE_CHECKING:
# Local import to avoid processing recorder and SQLite modules when running a
# testcase which does not use the recorder.
from homeassistant.components import recorder
pytest.register_assert_rewrite("tests.common")
from .common import ( # noqa: E402, isort:skip
@ -75,7 +81,6 @@ from .common import ( # noqa: E402, isort:skip
INSTANCES,
MockConfigEntry,
MockUser,
SetupRecorderInstanceT,
async_fire_mqtt_message,
async_test_home_assistant,
get_test_home_assistant,
@ -994,7 +999,7 @@ def enable_custom_integrations(hass):
@pytest.fixture
def enable_statistics():
def enable_statistics() -> bool:
"""Fixture to control enabling of recorder's statistics compilation.
To enable statistics, tests can be marked with:
@ -1004,7 +1009,7 @@ def enable_statistics():
@pytest.fixture
def enable_statistics_table_validation():
def enable_statistics_table_validation() -> bool:
"""Fixture to control enabling of recorder's statistics table validation.
To enable statistics table validation, tests can be marked with:
@ -1014,7 +1019,7 @@ def enable_statistics_table_validation():
@pytest.fixture
def enable_nightly_purge():
def enable_nightly_purge() -> bool:
"""Fixture to control enabling of recorder's nightly purge job.
To enable nightly purging, tests can be marked with:
@ -1024,7 +1029,7 @@ def enable_nightly_purge():
@pytest.fixture
def recorder_config():
def recorder_config() -> dict[str, Any] | None:
"""Fixture to override recorder config.
To override the config, tests can be marked with:
@ -1035,14 +1040,15 @@ def recorder_config():
@pytest.fixture
def recorder_db_url(
pytestconfig,
hass_fixture_setup,
):
pytestconfig: pytest.Config,
hass_fixture_setup: list[bool],
) -> Generator[str, None, None]:
"""Prepare a default database for tests and return a connection URL."""
assert not hass_fixture_setup
db_url: str = pytestconfig.getoption("dburl")
db_url = cast(str, pytestconfig.getoption("dburl"))
if db_url.startswith(("postgresql://", "mysql://")):
# pylint: disable-next=import-outside-toplevel
import sqlalchemy_utils
def _ha_orm_quote(mixed, ident):
@ -1060,18 +1066,21 @@ def recorder_db_url(
sqlalchemy_utils.functions.database.quote = _ha_orm_quote
if db_url.startswith("mysql://"):
# pylint: disable-next=import-outside-toplevel
import sqlalchemy_utils
charset = "utf8mb4' COLLATE = 'utf8mb4_unicode_ci"
assert not sqlalchemy_utils.database_exists(db_url)
sqlalchemy_utils.create_database(db_url, encoding=charset)
elif db_url.startswith("postgresql://"):
# pylint: disable-next=import-outside-toplevel
import sqlalchemy_utils
assert not sqlalchemy_utils.database_exists(db_url)
sqlalchemy_utils.create_database(db_url, encoding="utf8")
yield db_url
if db_url.startswith("mysql://"):
# pylint: disable-next=import-outside-toplevel
import sqlalchemy as sa
made_url = sa.make_url(db_url)
@ -1096,15 +1105,14 @@ def recorder_db_url(
@pytest.fixture
def hass_recorder(
recorder_db_url,
enable_nightly_purge,
enable_statistics,
enable_statistics_table_validation,
recorder_db_url: str,
enable_nightly_purge: bool,
enable_statistics: bool,
enable_statistics_table_validation: bool,
hass_storage,
):
) -> Generator[Callable[..., HomeAssistant], None, None]:
"""Home Assistant fixture with in-memory recorder."""
# Local import to avoid processing recorder and SQLite modules when running a
# testcase which does not use the recorder.
# pylint: disable-next=import-outside-toplevel
from homeassistant.components import recorder
original_tz = dt_util.DEFAULT_TIME_ZONE
@ -1131,7 +1139,7 @@ def hass_recorder(
autospec=True,
):
def setup_recorder(config=None):
def setup_recorder(config: dict[str, Any] | None = None) -> HomeAssistant:
"""Set up with params."""
init_recorder_component(hass, config, recorder_db_url)
hass.start()
@ -1146,10 +1154,13 @@ def hass_recorder(
dt_util.DEFAULT_TIME_ZONE = original_tz
async def _async_init_recorder_component(hass, add_config=None, db_url=None):
async def _async_init_recorder_component(
hass: HomeAssistant,
add_config: dict[str, Any] | None = None,
db_url: str | None = None,
) -> None:
"""Initialize the recorder asynchronously."""
# Local import to avoid processing recorder and SQLite modules when running a
# testcase which does not use the recorder.
# pylint: disable-next=import-outside-toplevel
from homeassistant.components import recorder
config = dict(add_config) if add_config else {}
@ -1173,16 +1184,16 @@ async def _async_init_recorder_component(hass, add_config=None, db_url=None):
@pytest.fixture
async def async_setup_recorder_instance(
recorder_db_url,
enable_nightly_purge,
enable_statistics,
enable_statistics_table_validation,
) -> AsyncGenerator[SetupRecorderInstanceT, None]:
recorder_db_url: str,
enable_nightly_purge: bool,
enable_statistics: bool,
enable_statistics_table_validation: bool,
) -> AsyncGenerator[RecorderInstanceGenerator, None]:
"""Yield callable to setup recorder instance."""
# Local import to avoid processing recorder and SQLite modules when running a
# testcase which does not use the recorder.
# pylint: disable-next=import-outside-toplevel
from homeassistant.components import recorder
# pylint: disable-next=import-outside-toplevel
from .components.recorder.common import async_recorder_block_till_done
nightly = recorder.Recorder.async_nightly_tasks if enable_nightly_purge else None
@ -1222,7 +1233,11 @@ async def async_setup_recorder_instance(
@pytest.fixture
async def recorder_mock(recorder_config, async_setup_recorder_instance, hass):
async def recorder_mock(
recorder_config: dict[str, Any] | None,
async_setup_recorder_instance: RecorderInstanceGenerator,
hass: HomeAssistant,
) -> recorder.Recorder:
"""Fixture with in-memory recorder."""
return await async_setup_recorder_instance(hass, recorder_config)

View File

@ -2,12 +2,17 @@
from __future__ import annotations
from collections.abc import Callable, Coroutine
from typing import Any
from typing import TYPE_CHECKING, Any, TypeAlias
from unittest.mock import MagicMock
from aiohttp import ClientWebSocketResponse
from aiohttp.test_utils import TestClient
if TYPE_CHECKING:
# Local import to avoid processing recorder module when running a
# testcase which does not use the recorder.
from homeassistant.components.recorder import Recorder
ClientSessionGenerator = Callable[..., Coroutine[Any, Any, TestClient]]
MqttMockPahoClient = MagicMock
"""MagicMock for `paho.mqtt.client.Client`"""
@ -15,4 +20,6 @@ MqttMockHAClient = MagicMock
"""MagicMock for `homeassistant.components.mqtt.MQTT`."""
MqttMockHAClientGenerator = Callable[..., Coroutine[Any, Any, MqttMockHAClient]]
"""MagicMock generator for `homeassistant.components.mqtt.MQTT`."""
RecorderInstanceGenerator: TypeAlias = Callable[..., Coroutine[Any, Any, "Recorder"]]
"""Instance generator for `homeassistant.components.recorder.Recorder`."""
WebSocketGenerator = Callable[..., Coroutine[Any, Any, ClientWebSocketResponse]]