Store runtime data inside ConfigEntry (#115669)

This commit is contained in:
Marc Mueller 2024-04-30 11:29:43 +02:00 committed by GitHub
parent 258e20bfc4
commit dace9b32de
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 118 additions and 23 deletions

View File

@ -7,7 +7,7 @@ from dataclasses import dataclass
from adguardhome import AdGuardHome, AdGuardHomeConnectionError
import voluptuous as vol
from homeassistant.config_entries import ConfigEntry
from homeassistant.config_entries import ConfigEntry, ConfigEntryState
from homeassistant.const import (
CONF_HOST,
CONF_NAME,
@ -43,6 +43,7 @@ SERVICE_REFRESH_SCHEMA = vol.Schema(
)
PLATFORMS = [Platform.SENSOR, Platform.SWITCH]
AdGuardConfigEntry = ConfigEntry["AdGuardData"]
@dataclass
@ -53,7 +54,7 @@ class AdGuardData:
version: str
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_setup_entry(hass: HomeAssistant, entry: AdGuardConfigEntry) -> bool:
"""Set up AdGuard Home from a config entry."""
session = async_get_clientsession(hass, entry.data[CONF_VERIFY_SSL])
adguard = AdGuardHome(
@ -71,7 +72,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
except AdGuardHomeConnectionError as exception:
raise ConfigEntryNotReady from exception
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = AdGuardData(adguard, version)
entry.runtime_data = AdGuardData(adguard, version)
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
@ -116,17 +117,20 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
return True
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_unload_entry(hass: HomeAssistant, entry: AdGuardConfigEntry) -> bool:
"""Unload AdGuard Home config entry."""
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
if unload_ok:
hass.data[DOMAIN].pop(entry.entry_id)
if not hass.data[DOMAIN]:
loaded_entries = [
entry
for entry in hass.config_entries.async_entries(DOMAIN)
if entry.state == ConfigEntryState.LOADED
]
if len(loaded_entries) == 1:
# This is the last loaded instance of AdGuard, deregister any services
hass.services.async_remove(DOMAIN, SERVICE_ADD_URL)
hass.services.async_remove(DOMAIN, SERVICE_REMOVE_URL)
hass.services.async_remove(DOMAIN, SERVICE_ENABLE_URL)
hass.services.async_remove(DOMAIN, SERVICE_DISABLE_URL)
hass.services.async_remove(DOMAIN, SERVICE_REFRESH)
del hass.data[DOMAIN]
return unload_ok

View File

@ -4,11 +4,11 @@ from __future__ import annotations
from adguardhome import AdGuardHomeError
from homeassistant.config_entries import SOURCE_HASSIO, ConfigEntry
from homeassistant.config_entries import SOURCE_HASSIO
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
from homeassistant.helpers.entity import Entity
from . import AdGuardData
from . import AdGuardConfigEntry, AdGuardData
from .const import DOMAIN, LOGGER
@ -21,7 +21,7 @@ class AdGuardHomeEntity(Entity):
def __init__(
self,
data: AdGuardData,
entry: ConfigEntry,
entry: AdGuardConfigEntry,
) -> None:
"""Initialize the AdGuard Home entity."""
self._entry = entry

View File

@ -10,12 +10,11 @@ from typing import Any
from adguardhome import AdGuardHome
from homeassistant.components.sensor import SensorEntity, SensorEntityDescription
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import PERCENTAGE, UnitOfTime
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from . import AdGuardData
from . import AdGuardConfigEntry, AdGuardData
from .const import DOMAIN
from .entity import AdGuardHomeEntity
@ -85,11 +84,11 @@ SENSORS: tuple[AdGuardHomeEntityDescription, ...] = (
async def async_setup_entry(
hass: HomeAssistant,
entry: ConfigEntry,
entry: AdGuardConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up AdGuard Home sensor based on a config entry."""
data: AdGuardData = hass.data[DOMAIN][entry.entry_id]
data = entry.runtime_data
async_add_entities(
[AdGuardHomeSensor(data, entry, description) for description in SENSORS],
@ -105,7 +104,7 @@ class AdGuardHomeSensor(AdGuardHomeEntity, SensorEntity):
def __init__(
self,
data: AdGuardData,
entry: ConfigEntry,
entry: AdGuardConfigEntry,
description: AdGuardHomeEntityDescription,
) -> None:
"""Initialize AdGuard Home sensor."""

View File

@ -10,11 +10,10 @@ from typing import Any
from adguardhome import AdGuardHome, AdGuardHomeError
from homeassistant.components.switch import SwitchEntity, SwitchEntityDescription
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from . import AdGuardData
from . import AdGuardConfigEntry, AdGuardData
from .const import DOMAIN, LOGGER
from .entity import AdGuardHomeEntity
@ -79,11 +78,11 @@ SWITCHES: tuple[AdGuardHomeSwitchEntityDescription, ...] = (
async def async_setup_entry(
hass: HomeAssistant,
entry: ConfigEntry,
entry: AdGuardConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up AdGuard Home switch based on a config entry."""
data: AdGuardData = hass.data[DOMAIN][entry.entry_id]
data = entry.runtime_data
async_add_entities(
[AdGuardHomeSwitch(data, entry, description) for description in SWITCHES],
@ -99,7 +98,7 @@ class AdGuardHomeSwitch(AdGuardHomeEntity, SwitchEntity):
def __init__(
self,
data: AdGuardData,
entry: ConfigEntry,
entry: AdGuardConfigEntry,
description: AdGuardHomeSwitchEntityDescription,
) -> None:
"""Initialize AdGuard Home switch."""

View File

@ -21,9 +21,10 @@ from functools import cached_property
import logging
from random import randint
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Self, TypeVar, cast
from typing import TYPE_CHECKING, Any, Generic, Self, cast
from async_interrupt import interrupt
from typing_extensions import TypeVar
from . import data_entry_flow, loader
from .components import persistent_notification
@ -124,6 +125,7 @@ SAVE_DELAY = 1
DISCOVERY_COOLDOWN = 1
_DataT = TypeVar("_DataT", default=Any)
_R = TypeVar("_R")
@ -266,13 +268,14 @@ class ConfigFlowResult(FlowResult, total=False):
version: int
class ConfigEntry:
class ConfigEntry(Generic[_DataT]):
"""Hold a configuration entry."""
entry_id: str
domain: str
title: str
data: MappingProxyType[str, Any]
runtime_data: _DataT
options: MappingProxyType[str, Any]
unique_id: str | None
state: ConfigEntryState

View File

@ -23,6 +23,10 @@ _COMMON_ARGUMENTS: dict[str, list[str]] = {
"hass": ["HomeAssistant", "HomeAssistant | None"]
}
_PLATFORMS: set[str] = {platform.value for platform in Platform}
_KNOWN_GENERIC_TYPES: set[str] = {
"ConfigEntry",
}
_KNOWN_GENERIC_TYPES_TUPLE = tuple(_KNOWN_GENERIC_TYPES)
class _Special(Enum):
@ -2977,6 +2981,16 @@ def _is_valid_type(
):
return True
# Allow subscripts or type aliases for generic types
if (
isinstance(node, nodes.Subscript)
and isinstance(node.value, nodes.Name)
and node.value.name in _KNOWN_GENERIC_TYPES
or isinstance(node, nodes.Name)
and node.name.endswith(_KNOWN_GENERIC_TYPES_TUPLE)
):
return True
# Name occurs when a namespace is not used, eg. "HomeAssistant"
if isinstance(node, nodes.Name) and node.name == expected_type:
return True

View File

@ -1196,3 +1196,79 @@ def test_pytest_invalid_function(
),
):
type_hint_checker.visit_asyncfunctiondef(func_node)
@pytest.mark.parametrize(
"entry_annotation",
[
"ConfigEntry",
"ConfigEntry[AdGuardData]",
"AdGuardConfigEntry", # prefix allowed for type aliases
],
)
def test_valid_generic(
linter: UnittestLinter, type_hint_checker: BaseChecker, entry_annotation: str
) -> None:
"""Ensure valid hints are accepted for generic types."""
func_node = astroid.extract_node(
f"""
async def async_setup_entry( #@
hass: HomeAssistant,
entry: {entry_annotation},
async_add_entities: AddEntitiesCallback,
) -> None:
pass
""",
"homeassistant.components.pylint_test.notify",
)
type_hint_checker.visit_module(func_node.parent)
with assert_no_messages(linter):
type_hint_checker.visit_asyncfunctiondef(func_node)
@pytest.mark.parametrize(
("entry_annotation", "end_col_offset"),
[
("Config", 17), # not generic
("ConfigEntryXX[Data]", 30), # generic type needs to match exactly
("ConfigEntryData", 26), # ConfigEntry should be the suffix
],
)
def test_invalid_generic(
linter: UnittestLinter,
type_hint_checker: BaseChecker,
entry_annotation: str,
end_col_offset: int,
) -> None:
"""Ensure invalid hints are rejected for generic types."""
func_node, entry_node = astroid.extract_node(
f"""
async def async_setup_entry( #@
hass: HomeAssistant,
entry: {entry_annotation}, #@
async_add_entities: AddEntitiesCallback,
) -> None:
pass
""",
"homeassistant.components.pylint_test.notify",
)
type_hint_checker.visit_module(func_node.parent)
with assert_adds_messages(
linter,
pylint.testutils.MessageTest(
msg_id="hass-argument-type",
node=entry_node,
args=(
2,
"ConfigEntry",
"async_setup_entry",
),
line=4,
col_offset=4,
end_line=4,
end_col_offset=end_col_offset,
),
):
type_hint_checker.visit_asyncfunctiondef(func_node)