Improve typing [util.decorator] (#67087)

This commit is contained in:
Marc Mueller 2022-02-23 20:58:42 +01:00 committed by GitHub
parent 46c2bd0eb0
commit ec980a574b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 59 additions and 38 deletions

View File

@ -22,6 +22,7 @@ homeassistant.helpers.script_variables
homeassistant.helpers.translation
homeassistant.util.async_
homeassistant.util.color
homeassistant.util.decorator
homeassistant.util.process
homeassistant.util.unit_system

View File

@ -16,7 +16,7 @@ from homeassistant.data_entry_flow import FlowResult
from homeassistant.exceptions import HomeAssistantError
from homeassistant.util.decorator import Registry
MULTI_FACTOR_AUTH_MODULES = Registry()
MULTI_FACTOR_AUTH_MODULES: Registry[str, type[MultiFactorAuthModule]] = Registry()
MULTI_FACTOR_AUTH_MODULE_SCHEMA = vol.Schema(
{
@ -129,7 +129,7 @@ async def auth_mfa_module_from_config(
hass: HomeAssistant, config: dict[str, Any]
) -> MultiFactorAuthModule:
"""Initialize an auth module from a config."""
module_name = config[CONF_TYPE]
module_name: str = config[CONF_TYPE]
module = await _load_mfa_module(hass, module_name)
try:
@ -142,7 +142,7 @@ async def auth_mfa_module_from_config(
)
raise
return MULTI_FACTOR_AUTH_MODULES[module_name](hass, config) # type: ignore[no-any-return]
return MULTI_FACTOR_AUTH_MODULES[module_name](hass, config)
async def _load_mfa_module(hass: HomeAssistant, module_name: str) -> types.ModuleType:

View File

@ -25,7 +25,7 @@ from ..models import Credentials, RefreshToken, User, UserMeta
_LOGGER = logging.getLogger(__name__)
DATA_REQS = "auth_prov_reqs_processed"
AUTH_PROVIDERS = Registry()
AUTH_PROVIDERS: Registry[str, type[AuthProvider]] = Registry()
AUTH_PROVIDER_SCHEMA = vol.Schema(
{
@ -136,7 +136,7 @@ async def auth_provider_from_config(
hass: HomeAssistant, store: AuthStore, config: dict[str, Any]
) -> AuthProvider:
"""Initialize an auth provider from a config."""
provider_name = config[CONF_TYPE]
provider_name: str = config[CONF_TYPE]
module = await load_auth_provider_module(hass, provider_name)
try:
@ -149,7 +149,7 @@ async def auth_provider_from_config(
)
raise
return AUTH_PROVIDERS[provider_name](hass, store, config) # type: ignore[no-any-return]
return AUTH_PROVIDERS[provider_name](hass, store, config)
async def load_auth_provider_module(

View File

@ -83,7 +83,7 @@ if TYPE_CHECKING:
_LOGGER = logging.getLogger(__name__)
ENTITY_ADAPTERS = Registry()
ENTITY_ADAPTERS: Registry[str, type[AlexaEntity]] = Registry()
TRANSLATION_TABLE = dict.fromkeys(map(ord, r"}{\/|\"()[]+~!><*%"), None)

View File

@ -73,7 +73,7 @@ from .errors import (
from .state_report import async_enable_proactive_mode
_LOGGER = logging.getLogger(__name__)
HANDLERS = Registry()
HANDLERS = Registry() # type: ignore[var-annotated]
@HANDLERS.register(("Alexa.Discovery", "Discover"))

View File

@ -12,7 +12,7 @@ from .const import DOMAIN, SYN_RESOLUTION_MATCH
_LOGGER = logging.getLogger(__name__)
HANDLERS = Registry()
HANDLERS = Registry() # type: ignore[var-annotated]
INTENTS_API_ENDPOINT = "/api/alexa"

View File

@ -52,7 +52,7 @@ FILTER_NAME_OUTLIER = "outlier"
FILTER_NAME_THROTTLE = "throttle"
FILTER_NAME_TIME_THROTTLE = "time_throttle"
FILTER_NAME_TIME_SMA = "time_simple_moving_average"
FILTERS = Registry()
FILTERS: Registry[str, type[Filter]] = Registry()
CONF_FILTERS = "filters"
CONF_FILTER_NAME = "filter"

View File

@ -19,7 +19,7 @@ from .helpers import GoogleEntity, RequestData, async_get_entities
EXECUTE_LIMIT = 2 # Wait 2 seconds for execute to finish
HANDLERS = Registry()
HANDLERS = Registry() # type: ignore[var-annotated]
_LOGGER = logging.getLogger(__name__)

View File

@ -1,4 +1,6 @@
"""Extend the basic Accessory and Bridge functions."""
from __future__ import annotations
import logging
from pyhap.accessory import Accessory, Bridge
@ -90,7 +92,7 @@ SWITCH_TYPES = {
TYPE_SWITCH: "Switch",
TYPE_VALVE: "Valve",
}
TYPES = Registry()
TYPES: Registry[str, type[HomeAccessory]] = Registry()
def get_accessory(hass, driver, state, aid, config): # noqa: C901

View File

@ -9,7 +9,7 @@ from homeassistant.util import decorator
from .const import CONF_INVERSE, SIGNAL_DS18B20_NEW
_LOGGER = logging.getLogger(__name__)
HANDLERS = decorator.Registry()
HANDLERS = decorator.Registry() # type: ignore[var-annotated]
@HANDLERS.register("state")

View File

@ -109,7 +109,7 @@ _LOGGER = logging.getLogger(__name__)
DELAY_SAVE = 10
WEBHOOK_COMMANDS = Registry()
WEBHOOK_COMMANDS = Registry() # type: ignore[var-annotated]
COMBINED_CLASSES = set(BINARY_SENSOR_CLASSES + SENSOR_CLASSES)
SENSOR_TYPES = [ATTR_SENSOR_TYPE_BINARY_SENSOR, ATTR_SENSOR_TYPE_SENSOR]

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import asyncio
from collections import defaultdict
from collections.abc import Callable, Coroutine
from collections.abc import Callable
import logging
import socket
import sys
@ -337,9 +337,7 @@ def _gw_callback_factory(
_LOGGER.debug("Node update: node %s child %s", msg.node_id, msg.child_id)
msg_type = msg.gateway.const.MessageType(msg.type)
msg_handler: Callable[
[HomeAssistant, GatewayId, Message], Coroutine[Any, Any, None]
] | None = HANDLERS.get(msg_type.name)
msg_handler = HANDLERS.get(msg_type.name)
if msg_handler is None:
return

View File

@ -1,6 +1,9 @@
"""Handle MySensors messages."""
from __future__ import annotations
from collections.abc import Callable, Coroutine
from typing import Any
from mysensors import Message
from homeassistant.const import Platform
@ -12,7 +15,9 @@ from .const import CHILD_CALLBACK, NODE_CALLBACK, DevId, GatewayId
from .device import get_mysensors_devices
from .helpers import discover_mysensors_platform, validate_set_msg
HANDLERS = decorator.Registry()
HANDLERS: decorator.Registry[
str, Callable[[HomeAssistant, GatewayId, Message], Coroutine[Any, Any, None]]
] = decorator.Registry()
@HANDLERS.register("set")

View File

@ -31,7 +31,9 @@ from .const import (
)
_LOGGER = logging.getLogger(__name__)
SCHEMAS = Registry()
SCHEMAS: Registry[
tuple[str, str], Callable[[BaseAsyncGateway, ChildSensor, ValueType], vol.Schema]
] = Registry()
@callback

View File

@ -1,10 +1,13 @@
"""ONVIF event parsers."""
from collections.abc import Callable, Coroutine
from typing import Any
from homeassistant.util import dt as dt_util
from homeassistant.util.decorator import Registry
from .models import Event
PARSERS = Registry()
PARSERS: Registry[str, Callable[[str, Any], Coroutine[Any, Any, Event]]] = Registry()
@PARSERS.register("tns1:VideoSource/MotionAlarm")

View File

@ -1,8 +1,10 @@
"""Helpers to help coordinate updates."""
from __future__ import annotations
from collections.abc import Callable, Coroutine
from datetime import timedelta
import logging
from typing import Any
from aiohttp import ServerDisconnectedError
from pyoverkiz.client import OverkizClient
@ -25,7 +27,9 @@ from homeassistant.util.decorator import Registry
from .const import DOMAIN, LOGGER, UPDATE_INTERVAL
EVENT_HANDLERS = Registry()
EVENT_HANDLERS: Registry[
str, Callable[[OverkizDataUpdateCoordinator, Event], Coroutine[Any, Any, None]]
] = Registry()
class OverkizDataUpdateCoordinator(DataUpdateCoordinator[dict[str, Device]]):

View File

@ -17,7 +17,7 @@ from .helper import supports_encryption
_LOGGER = logging.getLogger(__name__)
HANDLERS = decorator.Registry()
HANDLERS = decorator.Registry() # type: ignore[var-annotated]
def get_cipher():

View File

@ -245,7 +245,7 @@ class Stream:
self, fmt: str, timeout: int = OUTPUT_IDLE_TIMEOUT
) -> StreamOutput:
"""Add provider output stream."""
if not self._outputs.get(fmt):
if not (provider := self._outputs.get(fmt)):
@callback
def idle_callback() -> None:
@ -259,7 +259,7 @@ class Stream:
self.hass, IdleTimer(self.hass, timeout, idle_callback)
)
self._outputs[fmt] = provider
return self._outputs[fmt]
return provider
def remove_provider(self, provider: StreamOutput) -> None:
"""Remove provider output stream."""

View File

@ -23,7 +23,7 @@ if TYPE_CHECKING:
from . import Stream
PROVIDERS = Registry()
PROVIDERS: Registry[str, type[StreamOutput]] = Registry()
@attr.s(slots=True)

View File

@ -62,7 +62,7 @@ SOURCE_UNIGNORE = "unignore"
# This is used to signal that re-authentication is required by the user.
SOURCE_REAUTH = "reauth"
HANDLERS = Registry()
HANDLERS: Registry[str, type[ConfigFlow]] = Registry()
STORAGE_KEY = "core.config_entries"
STORAGE_VERSION = 1
@ -530,8 +530,10 @@ class ConfigEntry:
)
return False
# Handler may be a partial
# Keep for backwards compatibility
# https://github.com/home-assistant/core/pull/67087#discussion_r812559950
while isinstance(handler, functools.partial):
handler = handler.func
handler = handler.func # type: ignore[unreachable]
if self.version == handler.VERSION:
return True
@ -753,7 +755,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
if not context or "source" not in context:
raise KeyError("Context not set or doesn't have a source set")
flow = cast(ConfigFlow, handler())
flow = handler()
flow.init_step = context["source"]
return flow
@ -1496,7 +1498,7 @@ class OptionsFlowManager(data_entry_flow.FlowManager):
if entry.domain not in HANDLERS:
raise data_entry_flow.UnknownHandler
return cast(OptionsFlow, HANDLERS[entry.domain].async_get_options_flow(entry))
return HANDLERS[entry.domain].async_get_options_flow(entry)
async def async_finish_flow(
self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult

View File

@ -13,7 +13,7 @@ from homeassistant.util import decorator
from . import config_validation as cv
SELECTORS = decorator.Registry()
SELECTORS: decorator.Registry[str, type[Selector]] = decorator.Registry()
def _get_selector_class(config: Any) -> type[Selector]:
@ -24,12 +24,12 @@ def _get_selector_class(config: Any) -> type[Selector]:
if len(config) != 1:
raise vol.Invalid(f"Only one type can be specified. Found {', '.join(config)}")
selector_type = list(config)[0]
selector_type: str = list(config)[0]
if (selector_class := SELECTORS.get(selector_type)) is None:
raise vol.Invalid(f"Unknown selector type {selector_type} found")
return cast(type[Selector], selector_class)
return selector_class
def selector(config: Any) -> Selector:

View File

@ -2,18 +2,19 @@
from __future__ import annotations
from collections.abc import Callable, Hashable
from typing import TypeVar
from typing import Any, TypeVar
CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable) # pylint: disable=invalid-name
_KT = TypeVar("_KT", bound=Hashable)
_VT = TypeVar("_VT", bound=Callable[..., Any])
class Registry(dict):
class Registry(dict[_KT, _VT]):
"""Registry of items."""
def register(self, name: Hashable) -> Callable[[CALLABLE_T], CALLABLE_T]:
def register(self, name: _KT) -> Callable[[_VT], _VT]:
"""Return decorator to register item with a specific name."""
def decorator(func: CALLABLE_T) -> CALLABLE_T:
def decorator(func: _VT) -> _VT:
"""Register decorated function."""
self[name] = func
return func

View File

@ -76,6 +76,9 @@ disallow_any_generics = true
[mypy-homeassistant.util.color]
disallow_any_generics = true
[mypy-homeassistant.util.decorator]
disallow_any_generics = true
[mypy-homeassistant.util.process]
disallow_any_generics = true