Add ComponentProtocol to improve type checking (#90586)

This commit is contained in:
epenet 2023-03-31 20:19:58 +02:00 committed by GitHub
parent 03137feba5
commit 611d4135fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 77 additions and 17 deletions

View File

@ -61,7 +61,7 @@ from .helpers import (
)
from .helpers.entity_values import EntityValues
from .helpers.typing import ConfigType
from .loader import Integration, IntegrationNotFound
from .loader import ComponentProtocol, Integration, IntegrationNotFound
from .requirements import RequirementsNotFound, async_get_integration_with_requirements
from .util.package import is_docker_env
from .util.unit_system import get_unit_system, validate_unit_system
@ -681,7 +681,7 @@ def _log_pkg_error(package: str, component: str, config: dict, message: str) ->
_LOGGER.error(message)
def _identify_config_schema(module: ModuleType) -> str | None:
def _identify_config_schema(module: ComponentProtocol) -> str | None:
"""Extract the schema and identify list or dict based."""
if not isinstance(module.CONFIG_SCHEMA, vol.Schema):
return None

View File

@ -383,7 +383,7 @@ class ConfigEntry:
result = await component.async_setup_entry(hass, self)
if not isinstance(result, bool):
_LOGGER.error(
_LOGGER.error( # type: ignore[unreachable]
"%s.async_setup_entry did not return boolean", integration.domain
)
result = False
@ -546,8 +546,7 @@ class ConfigEntry:
await self._async_process_on_unload()
# https://github.com/python/mypy/issues/11839
return result # type: ignore[no-any-return]
return result
except Exception as ex: # pylint: disable=broad-except
_LOGGER.exception(
"Error unloading entry %s for %s", self.title, integration.domain
@ -628,15 +627,14 @@ class ConfigEntry:
try:
result = await component.async_migrate_entry(hass, self)
if not isinstance(result, bool):
_LOGGER.error(
_LOGGER.error( # type: ignore[unreachable]
"%s.async_migrate_entry did not return boolean", self.domain
)
return False
if result:
# pylint: disable-next=protected-access
hass.config_entries._async_schedule_save()
# https://github.com/python/mypy/issues/11839
return result # type: ignore[no-any-return]
return result
except Exception: # pylint: disable=broad-except
_LOGGER.exception(
"Error migrating entry %s for %s", self.title, self.domain

View File

@ -15,13 +15,14 @@ import logging
import pathlib
import sys
from types import ModuleType
from typing import TYPE_CHECKING, Any, Literal, TypedDict, TypeVar, cast
from typing import TYPE_CHECKING, Any, Literal, Protocol, TypedDict, TypeVar, cast
from awesomeversion import (
AwesomeVersion,
AwesomeVersionException,
AwesomeVersionStrategy,
)
import voluptuous as vol
from . import generated
from .generated.application_credentials import APPLICATION_CREDENTIALS
@ -35,7 +36,10 @@ from .util.json import JSON_DECODE_EXCEPTIONS, json_loads
# Typing imports that create a circular dependency
if TYPE_CHECKING:
from .config_entries import ConfigEntry
from .core import HomeAssistant
from .helpers import device_registry as dr
from .helpers.typing import ConfigType
_CallableT = TypeVar("_CallableT", bound=Callable[..., Any])
@ -260,6 +264,52 @@ async def async_get_config_flows(
return flows
class ComponentProtocol(Protocol):
"""Define the format of an integration."""
CONFIG_SCHEMA: vol.Schema
DOMAIN: str
async def async_setup_entry(
self, hass: HomeAssistant, config_entry: ConfigEntry
) -> bool:
"""Set up a config entry."""
async def async_unload_entry(
self, hass: HomeAssistant, config_entry: ConfigEntry
) -> bool:
"""Unload a config entry."""
async def async_migrate_entry(
self, hass: HomeAssistant, config_entry: ConfigEntry
) -> bool:
"""Migrate an old config entry."""
async def async_remove_entry(
self, hass: HomeAssistant, config_entry: ConfigEntry
) -> None:
"""Remove a config entry."""
async def async_remove_config_entry_device(
self,
hass: HomeAssistant,
config_entry: ConfigEntry,
device_entry: dr.DeviceEntry,
) -> bool:
"""Remove a config entry device."""
async def async_reset_platform(
self, hass: HomeAssistant, integration_name: str
) -> None:
"""Release resources."""
async def async_setup(self, hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up integration."""
def setup(self, hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up integration."""
async def async_get_integration_descriptions(
hass: HomeAssistant,
) -> dict[str, Any]:
@ -750,14 +800,18 @@ class Integration:
return self._all_dependencies_resolved
def get_component(self) -> ModuleType:
def get_component(self) -> ComponentProtocol:
"""Return the component."""
cache: dict[str, ModuleType] = self.hass.data.setdefault(DATA_COMPONENTS, {})
cache: dict[str, ComponentProtocol] = self.hass.data.setdefault(
DATA_COMPONENTS, {}
)
if self.domain in cache:
return cache[self.domain]
try:
cache[self.domain] = importlib.import_module(self.pkg_path)
cache[self.domain] = cast(
ComponentProtocol, importlib.import_module(self.pkg_path)
)
except ImportError:
raise
except Exception as err:
@ -922,7 +976,7 @@ class CircularDependency(LoaderError):
def _load_file(
hass: HomeAssistant, comp_or_platform: str, base_paths: list[str]
) -> ModuleType | None:
) -> ComponentProtocol | None:
"""Try to load specified file.
Looks in config dir first, then built-in components.
@ -957,7 +1011,7 @@ def _load_file(
cache[comp_or_platform] = module
return module
return cast(ComponentProtocol, module)
except ImportError as err:
# This error happens if for example custom_components/switch
@ -981,7 +1035,7 @@ def _load_file(
class ModuleWrapper:
"""Class to wrap a Python module and auto fill in hass argument."""
def __init__(self, hass: HomeAssistant, module: ModuleType) -> None:
def __init__(self, hass: HomeAssistant, module: ComponentProtocol) -> None:
"""Initialize the module wrapper."""
self._hass = hass
self._module = module
@ -1010,7 +1064,7 @@ class Components:
integration = self._hass.data.get(DATA_INTEGRATIONS, {}).get(comp_name)
if isinstance(integration, Integration):
component: ModuleType | None = integration.get_component()
component: ComponentProtocol | None = integration.get_component()
else:
# Fallback to importing old-school
component = _load_file(self._hass, comp_name, _lookup_path(self._hass))

View File

@ -236,7 +236,7 @@ async def _async_setup_component(
SLOW_SETUP_WARNING,
)
task = None
task: Awaitable[bool] | None = None
result: Any | bool = True
try:
if hasattr(component, "async_setup"):

View File

@ -202,6 +202,14 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = {
},
return_type="bool",
),
TypeHintMatch(
function_name="async_reset_platform",
arg_types={
0: "HomeAssistant",
1: "str",
},
return_type=None,
),
],
"__any_platform__": [
TypeHintMatch(