Remove HomeAssistantType alias from helpers (#48400)

This commit is contained in:
Franck Nijhof 2021-03-27 12:55:24 +01:00 committed by GitHub
parent 4a353efdfb
commit 38d14702fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 141 additions and 157 deletions

View File

@ -14,9 +14,8 @@ from aiohttp.web_exceptions import HTTPBadGateway, HTTPGatewayTimeout
import async_timeout
from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE, __version__
from homeassistant.core import Event, callback
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.helpers.frame import warn_use
from homeassistant.helpers.typing import HomeAssistantType
from homeassistant.loader import bind_hass
from homeassistant.util import ssl as ssl_util
@ -32,7 +31,7 @@ SERVER_SOFTWARE = "HomeAssistant/{0} aiohttp/{1} Python/{2[0]}.{2[1]}".format(
@callback
@bind_hass
def async_get_clientsession(
hass: HomeAssistantType, verify_ssl: bool = True
hass: HomeAssistant, verify_ssl: bool = True
) -> aiohttp.ClientSession:
"""Return default aiohttp ClientSession.
@ -51,7 +50,7 @@ def async_get_clientsession(
@callback
@bind_hass
def async_create_clientsession(
hass: HomeAssistantType,
hass: HomeAssistant,
verify_ssl: bool = True,
auto_cleanup: bool = True,
**kwargs: Any,
@ -84,7 +83,7 @@ def async_create_clientsession(
@bind_hass
async def async_aiohttp_proxy_web(
hass: HomeAssistantType,
hass: HomeAssistant,
request: web.BaseRequest,
web_coro: Awaitable[aiohttp.ClientResponse],
buffer_size: int = 102400,
@ -117,7 +116,7 @@ async def async_aiohttp_proxy_web(
@bind_hass
async def async_aiohttp_proxy_stream(
hass: HomeAssistantType,
hass: HomeAssistant,
request: web.BaseRequest,
stream: aiohttp.StreamReader,
content_type: str | None,
@ -145,7 +144,7 @@ async def async_aiohttp_proxy_stream(
@callback
def _async_register_clientsession_shutdown(
hass: HomeAssistantType, clientsession: aiohttp.ClientSession
hass: HomeAssistant, clientsession: aiohttp.ClientSession
) -> None:
"""Register ClientSession close on Home Assistant shutdown.
@ -162,7 +161,7 @@ def _async_register_clientsession_shutdown(
@callback
def _async_get_connector(
hass: HomeAssistantType, verify_ssl: bool = True
hass: HomeAssistant, verify_ssl: bool = True
) -> aiohttp.BaseConnector:
"""Return the connector pool for aiohttp.

View File

@ -6,13 +6,11 @@ from typing import Container, Iterable, MutableMapping, cast
import attr
from homeassistant.core import callback
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.loader import bind_hass
from homeassistant.util import slugify
from .typing import HomeAssistantType
# mypy: disallow-any-generics
DATA_REGISTRY = "area_registry"
@ -43,7 +41,7 @@ class AreaEntry:
class AreaRegistry:
"""Class to hold a registry of areas."""
def __init__(self, hass: HomeAssistantType) -> None:
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the area registry."""
self.hass = hass
self.areas: MutableMapping[str, AreaEntry] = {}
@ -186,12 +184,12 @@ class AreaRegistry:
@callback
def async_get(hass: HomeAssistantType) -> AreaRegistry:
def async_get(hass: HomeAssistant) -> AreaRegistry:
"""Get area registry."""
return cast(AreaRegistry, hass.data[DATA_REGISTRY])
async def async_load(hass: HomeAssistantType) -> None:
async def async_load(hass: HomeAssistant) -> None:
"""Load area registry."""
assert DATA_REGISTRY not in hass.data
hass.data[DATA_REGISTRY] = AreaRegistry(hass)
@ -199,7 +197,7 @@ async def async_load(hass: HomeAssistantType) -> None:
@bind_hass
async def async_get_registry(hass: HomeAssistantType) -> AreaRegistry:
async def async_get_registry(hass: HomeAssistant) -> AreaRegistry:
"""Get area registry.
This is deprecated and will be removed in the future. Use async_get instead.

View File

@ -18,7 +18,6 @@ from homeassistant.helpers import entity_registry
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.storage import Store
from homeassistant.helpers.typing import HomeAssistantType
from homeassistant.util import slugify
STORAGE_VERSION = 1
@ -303,7 +302,7 @@ class IDLessCollection(ObservableCollection):
@callback
def sync_entity_lifecycle(
hass: HomeAssistantType,
hass: HomeAssistant,
domain: str,
platform: str,
entity_component: EntityComponent,

View File

@ -4,8 +4,7 @@ from __future__ import annotations
from typing import Any, Awaitable, Callable, Union
from homeassistant import config_entries
from .typing import HomeAssistantType
from homeassistant.core import HomeAssistant
DiscoveryFunctionType = Callable[[], Union[Awaitable[bool], bool]]
@ -182,7 +181,7 @@ def register_webhook_flow(
async def webhook_async_remove_entry(
hass: HomeAssistantType, entry: config_entries.ConfigEntry
hass: HomeAssistant, entry: config_entries.ConfigEntry
) -> None:
"""Remove a webhook config entry."""
if not entry.data.get("cloudhook") or "cloud" not in hass.config.components:

View File

@ -9,12 +9,12 @@ from typing import TYPE_CHECKING, Any, cast
import attr
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED
from homeassistant.core import Event, callback
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.loader import bind_hass
import homeassistant.util.uuid as uuid_util
from .debounce import Debouncer
from .typing import UNDEFINED, HomeAssistantType, UndefinedType
from .typing import UNDEFINED, UndefinedType
# mypy: disallow_any_generics
@ -139,7 +139,7 @@ class DeviceRegistry:
deleted_devices: dict[str, DeletedDeviceEntry]
_devices_index: dict[str, dict[str, dict[tuple[str, str], str]]]
def __init__(self, hass: HomeAssistantType) -> None:
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the device registry."""
self.hass = hass
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
@ -617,12 +617,12 @@ class DeviceRegistry:
@callback
def async_get(hass: HomeAssistantType) -> DeviceRegistry:
def async_get(hass: HomeAssistant) -> DeviceRegistry:
"""Get device registry."""
return cast(DeviceRegistry, hass.data[DATA_REGISTRY])
async def async_load(hass: HomeAssistantType) -> None:
async def async_load(hass: HomeAssistant) -> None:
"""Load device registry."""
assert DATA_REGISTRY not in hass.data
hass.data[DATA_REGISTRY] = DeviceRegistry(hass)
@ -630,7 +630,7 @@ async def async_load(hass: HomeAssistantType) -> None:
@bind_hass
async def async_get_registry(hass: HomeAssistantType) -> DeviceRegistry:
async def async_get_registry(hass: HomeAssistant) -> DeviceRegistry:
"""Get device registry.
This is deprecated and will be removed in the future. Use async_get instead.
@ -686,7 +686,7 @@ def async_config_entry_disabled_by_changed(
@callback
def async_cleanup(
hass: HomeAssistantType,
hass: HomeAssistant,
dev_reg: DeviceRegistry,
ent_reg: entity_registry.EntityRegistry,
) -> None:
@ -723,7 +723,7 @@ def async_cleanup(
@callback
def async_setup_cleanup(hass: HomeAssistantType, dev_reg: DeviceRegistry) -> None:
def async_setup_cleanup(hass: HomeAssistant, dev_reg: DeviceRegistry) -> None:
"""Clean up device registry when entities removed."""
from . import entity_registry # pylint: disable=import-outside-toplevel

View File

@ -2,20 +2,18 @@
import logging
from typing import Any, Callable
from homeassistant.core import HassJob, callback
from homeassistant.core import HassJob, HomeAssistant, callback
from homeassistant.loader import bind_hass
from homeassistant.util.async_ import run_callback_threadsafe
from homeassistant.util.logging import catch_log_exception
from .typing import HomeAssistantType
_LOGGER = logging.getLogger(__name__)
DATA_DISPATCHER = "dispatcher"
@bind_hass
def dispatcher_connect(
hass: HomeAssistantType, signal: str, target: Callable[..., None]
hass: HomeAssistant, signal: str, target: Callable[..., None]
) -> Callable[[], None]:
"""Connect a callable function to a signal."""
async_unsub = run_callback_threadsafe(
@ -32,7 +30,7 @@ def dispatcher_connect(
@callback
@bind_hass
def async_dispatcher_connect(
hass: HomeAssistantType, signal: str, target: Callable[..., Any]
hass: HomeAssistant, signal: str, target: Callable[..., Any]
) -> Callable[[], None]:
"""Connect a callable function to a signal.
@ -69,14 +67,14 @@ def async_dispatcher_connect(
@bind_hass
def dispatcher_send(hass: HomeAssistantType, signal: str, *args: Any) -> None:
def dispatcher_send(hass: HomeAssistant, signal: str, *args: Any) -> None:
"""Send signal and data."""
hass.loop.call_soon_threadsafe(async_dispatcher_send, hass, signal, *args)
@callback
@bind_hass
def async_dispatcher_send(hass: HomeAssistantType, signal: str, *args: Any) -> None:
def async_dispatcher_send(hass: HomeAssistant, signal: str, *args: Any) -> None:
"""Send signal and data.
This method must be run in the event loop.

View File

@ -12,6 +12,7 @@ from homeassistant import config_entries
from homeassistant.const import ATTR_RESTORED, DEVICE_DEFAULT_NAME
from homeassistant.core import (
CALLBACK_TYPE,
HomeAssistant,
ServiceCall,
callback,
split_entity_id,
@ -24,7 +25,6 @@ from homeassistant.helpers import (
entity_registry as ent_reg,
service,
)
from homeassistant.helpers.typing import HomeAssistantType
from homeassistant.util.async_ import run_callback_threadsafe
from .entity_registry import DISABLED_INTEGRATION
@ -50,7 +50,7 @@ class EntityPlatform:
def __init__(
self,
*,
hass: HomeAssistantType,
hass: HomeAssistant,
logger: Logger,
domain: str,
platform_name: str,
@ -633,7 +633,7 @@ current_platform: ContextVar[EntityPlatform | None] = ContextVar(
@callback
def async_get_platforms(
hass: HomeAssistantType, integration_name: str
hass: HomeAssistant, integration_name: str
) -> list[EntityPlatform]:
"""Find existing platforms."""
if (

View File

@ -25,14 +25,20 @@ from homeassistant.const import (
EVENT_HOMEASSISTANT_START,
STATE_UNAVAILABLE,
)
from homeassistant.core import Event, callback, split_entity_id, valid_entity_id
from homeassistant.core import (
Event,
HomeAssistant,
callback,
split_entity_id,
valid_entity_id,
)
from homeassistant.helpers import device_registry as dr
from homeassistant.helpers.device_registry import EVENT_DEVICE_REGISTRY_UPDATED
from homeassistant.loader import bind_hass
from homeassistant.util import slugify
from homeassistant.util.yaml import load_yaml
from .typing import UNDEFINED, HomeAssistantType, UndefinedType
from .typing import UNDEFINED, UndefinedType
if TYPE_CHECKING:
from homeassistant.config_entries import ConfigEntry
@ -109,7 +115,7 @@ class RegistryEntry:
return self.disabled_by is not None
@callback
def write_unavailable_state(self, hass: HomeAssistantType) -> None:
def write_unavailable_state(self, hass: HomeAssistant) -> None:
"""Write the unavailable state to the state machine."""
attrs: dict[str, Any] = {ATTR_RESTORED: True}
@ -139,7 +145,7 @@ class RegistryEntry:
class EntityRegistry:
"""Class to hold a registry of entities."""
def __init__(self, hass: HomeAssistantType):
def __init__(self, hass: HomeAssistant):
"""Initialize the registry."""
self.hass = hass
self.entities: dict[str, RegistryEntry]
@ -572,12 +578,12 @@ class EntityRegistry:
@callback
def async_get(hass: HomeAssistantType) -> EntityRegistry:
def async_get(hass: HomeAssistant) -> EntityRegistry:
"""Get entity registry."""
return cast(EntityRegistry, hass.data[DATA_REGISTRY])
async def async_load(hass: HomeAssistantType) -> None:
async def async_load(hass: HomeAssistant) -> None:
"""Load entity registry."""
assert DATA_REGISTRY not in hass.data
hass.data[DATA_REGISTRY] = EntityRegistry(hass)
@ -585,7 +591,7 @@ async def async_load(hass: HomeAssistantType) -> None:
@bind_hass
async def async_get_registry(hass: HomeAssistantType) -> EntityRegistry:
async def async_get_registry(hass: HomeAssistant) -> EntityRegistry:
"""Get entity registry.
This is deprecated and will be removed in the future. Use async_get instead.
@ -666,9 +672,7 @@ async def _async_migrate(entities: dict[str, Any]) -> dict[str, list[dict[str, A
@callback
def async_setup_entity_restore(
hass: HomeAssistantType, registry: EntityRegistry
) -> None:
def async_setup_entity_restore(hass: HomeAssistant, registry: EntityRegistry) -> None:
"""Set up the entity restore mechanism."""
@callback
@ -710,7 +714,7 @@ def async_setup_entity_restore(
async def async_migrate_entries(
hass: HomeAssistantType,
hass: HomeAssistant,
config_entry_id: str,
entry_callback: Callable[[RegistryEntry], dict | None],
) -> None:

View File

@ -7,9 +7,8 @@ from typing import Any, Callable
import httpx
from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE, __version__
from homeassistant.core import Event, callback
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.helpers.frame import warn_use
from homeassistant.helpers.typing import HomeAssistantType
from homeassistant.loader import bind_hass
DATA_ASYNC_CLIENT = "httpx_async_client"
@ -22,9 +21,7 @@ USER_AGENT = "User-Agent"
@callback
@bind_hass
def get_async_client(
hass: HomeAssistantType, verify_ssl: bool = True
) -> httpx.AsyncClient:
def get_async_client(hass: HomeAssistant, verify_ssl: bool = True) -> httpx.AsyncClient:
"""Return default httpx AsyncClient.
This method must be run in the event loop.
@ -52,7 +49,7 @@ class HassHttpXAsyncClient(httpx.AsyncClient):
@callback
def create_async_httpx_client(
hass: HomeAssistantType,
hass: HomeAssistant,
verify_ssl: bool = True,
auto_cleanup: bool = True,
**kwargs: Any,
@ -84,7 +81,7 @@ def create_async_httpx_client(
@callback
def _async_register_async_client_shutdown(
hass: HomeAssistantType,
hass: HomeAssistant,
client: httpx.AsyncClient,
original_aclose: Callable[..., Any],
) -> None:

View File

@ -8,10 +8,9 @@ from typing import Any, Callable, Dict, Iterable
import voluptuous as vol
from homeassistant.const import ATTR_ENTITY_ID, ATTR_SUPPORTED_FEATURES
from homeassistant.core import Context, State, T, callback
from homeassistant.core import Context, HomeAssistant, State, T, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.typing import HomeAssistantType
from homeassistant.loader import bind_hass
_LOGGER = logging.getLogger(__name__)
@ -31,7 +30,7 @@ SPEECH_TYPE_SSML = "ssml"
@callback
@bind_hass
def async_register(hass: HomeAssistantType, handler: IntentHandler) -> None:
def async_register(hass: HomeAssistant, handler: IntentHandler) -> None:
"""Register an intent with Home Assistant."""
intents = hass.data.get(DATA_KEY)
if intents is None:
@ -49,7 +48,7 @@ def async_register(hass: HomeAssistantType, handler: IntentHandler) -> None:
@bind_hass
async def async_handle(
hass: HomeAssistantType,
hass: HomeAssistant,
platform: str,
intent_type: str,
slots: _SlotsType | None = None,
@ -103,7 +102,7 @@ class IntentUnexpectedError(IntentError):
@callback
@bind_hass
def async_match_state(
hass: HomeAssistantType, name: str, states: Iterable[State] | None = None
hass: HomeAssistant, name: str, states: Iterable[State] | None = None
) -> State:
"""Find a state that matches the name."""
if states is None:
@ -222,7 +221,7 @@ class Intent:
def __init__(
self,
hass: HomeAssistantType,
hass: HomeAssistant,
platform: str,
intent_type: str,
slots: _SlotsType,

View File

@ -7,8 +7,7 @@ from typing import Sequence
import voluptuous as vol
from homeassistant.const import ATTR_LATITUDE, ATTR_LONGITUDE
from homeassistant.core import State
from homeassistant.helpers.typing import HomeAssistantType
from homeassistant.core import HomeAssistant, State
from homeassistant.util import location as loc_util
_LOGGER = logging.getLogger(__name__)
@ -49,7 +48,7 @@ def closest(latitude: float, longitude: float, states: Sequence[State]) -> State
def find_coordinates(
hass: HomeAssistantType, entity_id: str, recursion_history: list | None = None
hass: HomeAssistant, entity_id: str, recursion_history: list | None = None
) -> str | None:
"""Find the gps coordinates of the entity in the form of '90.000,180.000'."""
entity_state = hass.states.get(entity_id)

View File

@ -7,11 +7,11 @@ from typing import Iterable
from homeassistant import config as conf_util
from homeassistant.const import SERVICE_RELOAD
from homeassistant.core import Event, callback
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_per_platform
from homeassistant.helpers.entity_platform import EntityPlatform, async_get_platforms
from homeassistant.helpers.typing import ConfigType, HomeAssistantType
from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import async_get_integration
from homeassistant.setup import async_setup_component
@ -19,7 +19,7 @@ _LOGGER = logging.getLogger(__name__)
async def async_reload_integration_platforms(
hass: HomeAssistantType, integration_name: str, integration_platforms: Iterable
hass: HomeAssistant, integration_name: str, integration_platforms: Iterable
) -> None:
"""Reload an integration's platforms.
@ -47,7 +47,7 @@ async def async_reload_integration_platforms(
async def _resetup_platform(
hass: HomeAssistantType,
hass: HomeAssistant,
integration_name: str,
integration_platform: str,
unprocessed_conf: ConfigType,
@ -99,7 +99,7 @@ async def _resetup_platform(
async def _async_setup_platform(
hass: HomeAssistantType,
hass: HomeAssistant,
integration_name: str,
integration_platform: str,
platform_configs: list[dict],
@ -129,7 +129,7 @@ async def _async_reconfig_platform(
async def async_integration_yaml_config(
hass: HomeAssistantType, integration_name: str
hass: HomeAssistant, integration_name: str
) -> ConfigType | None:
"""Fetch the latest yaml configuration for an integration."""
integration = await async_get_integration(hass, integration_name)
@ -141,7 +141,7 @@ async def async_integration_yaml_config(
@callback
def async_get_platform_without_config_entry(
hass: HomeAssistantType, integration_name: str, integration_platform_name: str
hass: HomeAssistant, integration_name: str, integration_platform_name: str
) -> EntityPlatform | None:
"""Find an existing platform that is not a config entry."""
for integration_platform in async_get_platforms(hass, integration_name):
@ -155,7 +155,7 @@ def async_get_platform_without_config_entry(
async def async_setup_reload_service(
hass: HomeAssistantType, domain: str, platforms: Iterable
hass: HomeAssistant, domain: str, platforms: Iterable
) -> None:
"""Create the reload service for the domain."""
if hass.services.has_service(domain, SERVICE_RELOAD):
@ -171,9 +171,7 @@ async def async_setup_reload_service(
)
def setup_reload_service(
hass: HomeAssistantType, domain: str, platforms: Iterable
) -> None:
def setup_reload_service(hass: HomeAssistant, domain: str, platforms: Iterable) -> None:
"""Sync version of async_setup_reload_service."""
asyncio.run_coroutine_threadsafe(
async_setup_reload_service(hass, domain, platforms),

View File

@ -22,7 +22,7 @@ from homeassistant.const import (
ENTITY_MATCH_ALL,
ENTITY_MATCH_NONE,
)
import homeassistant.core as ha
from homeassistant.core import Context, HomeAssistant, ServiceCall, callback
from homeassistant.exceptions import (
HomeAssistantError,
TemplateError,
@ -36,7 +36,7 @@ from homeassistant.helpers import (
entity_registry,
template,
)
from homeassistant.helpers.typing import ConfigType, HomeAssistantType, TemplateVarsType
from homeassistant.helpers.typing import ConfigType, TemplateVarsType
from homeassistant.loader import (
MAX_LOAD_CONCURRENTLY,
Integration,
@ -72,7 +72,7 @@ class ServiceParams(TypedDict):
class ServiceTargetSelector:
"""Class to hold a target selector for a service."""
def __init__(self, service_call: ha.ServiceCall):
def __init__(self, service_call: ServiceCall):
"""Extract ids from service call data."""
entity_ids: str | list | None = service_call.data.get(ATTR_ENTITY_ID)
device_ids: str | list | None = service_call.data.get(ATTR_DEVICE_ID)
@ -129,7 +129,7 @@ class SelectedEntities:
@bind_hass
def call_from_config(
hass: HomeAssistantType,
hass: HomeAssistant,
config: ConfigType,
blocking: bool = False,
variables: TemplateVarsType = None,
@ -144,12 +144,12 @@ def call_from_config(
@bind_hass
async def async_call_from_config(
hass: HomeAssistantType,
hass: HomeAssistant,
config: ConfigType,
blocking: bool = False,
variables: TemplateVarsType = None,
validate_config: bool = True,
context: ha.Context | None = None,
context: Context | None = None,
) -> None:
"""Call a service based on a config hash."""
try:
@ -164,10 +164,10 @@ async def async_call_from_config(
await hass.services.async_call(**params, blocking=blocking, context=context)
@ha.callback
@callback
@bind_hass
def async_prepare_call_from_config(
hass: HomeAssistantType,
hass: HomeAssistant,
config: ConfigType,
variables: TemplateVarsType = None,
validate_config: bool = False,
@ -246,7 +246,7 @@ def async_prepare_call_from_config(
@bind_hass
def extract_entity_ids(
hass: HomeAssistantType, service_call: ha.ServiceCall, expand_group: bool = True
hass: HomeAssistant, service_call: ServiceCall, expand_group: bool = True
) -> set[str]:
"""Extract a list of entity ids from a service call.
@ -259,9 +259,9 @@ def extract_entity_ids(
@bind_hass
async def async_extract_entities(
hass: HomeAssistantType,
hass: HomeAssistant,
entities: Iterable[Entity],
service_call: ha.ServiceCall,
service_call: ServiceCall,
expand_group: bool = True,
) -> list[Entity]:
"""Extract a list of entity objects from a service call.
@ -298,7 +298,7 @@ async def async_extract_entities(
@bind_hass
async def async_extract_entity_ids(
hass: HomeAssistantType, service_call: ha.ServiceCall, expand_group: bool = True
hass: HomeAssistant, service_call: ServiceCall, expand_group: bool = True
) -> set[str]:
"""Extract a set of entity ids from a service call.
@ -317,7 +317,7 @@ def _has_match(ids: str | list | None) -> bool:
@bind_hass
async def async_extract_referenced_entity_ids(
hass: HomeAssistantType, service_call: ha.ServiceCall, expand_group: bool = True
hass: HomeAssistant, service_call: ServiceCall, expand_group: bool = True
) -> SelectedEntities:
"""Extract referenced entity IDs from a service call."""
selector = ServiceTargetSelector(service_call)
@ -367,7 +367,7 @@ async def async_extract_referenced_entity_ids(
@bind_hass
async def async_extract_config_entry_ids(
hass: HomeAssistantType, service_call: ha.ServiceCall, expand_group: bool = True
hass: HomeAssistant, service_call: ServiceCall, expand_group: bool = True
) -> set:
"""Extract referenced config entry ids from a service call."""
referenced = await async_extract_referenced_entity_ids(
@ -392,7 +392,7 @@ async def async_extract_config_entry_ids(
return config_entry_ids
def _load_services_file(hass: HomeAssistantType, integration: Integration) -> JSON_TYPE:
def _load_services_file(hass: HomeAssistant, integration: Integration) -> JSON_TYPE:
"""Load services file for an integration."""
try:
return load_yaml(str(integration.file_path / "services.yaml"))
@ -409,7 +409,7 @@ def _load_services_file(hass: HomeAssistantType, integration: Integration) -> JS
def _load_services_files(
hass: HomeAssistantType, integrations: Iterable[Integration]
hass: HomeAssistant, integrations: Iterable[Integration]
) -> list[JSON_TYPE]:
"""Load service files for multiple intergrations."""
return [_load_services_file(hass, integration) for integration in integrations]
@ -417,7 +417,7 @@ def _load_services_files(
@bind_hass
async def async_get_all_descriptions(
hass: HomeAssistantType,
hass: HomeAssistant,
) -> dict[str, dict[str, Any]]:
"""Return descriptions (i.e. user documentation) for all service calls."""
descriptions_cache = hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {})
@ -482,10 +482,10 @@ async def async_get_all_descriptions(
return descriptions
@ha.callback
@callback
@bind_hass
def async_set_service_schema(
hass: HomeAssistantType, domain: str, service: str, schema: dict[str, Any]
hass: HomeAssistant, domain: str, service: str, schema: dict[str, Any]
) -> None:
"""Register a description for a service."""
hass.data.setdefault(SERVICE_DESCRIPTION_CACHE, {})
@ -504,10 +504,10 @@ def async_set_service_schema(
@bind_hass
async def entity_service_call(
hass: HomeAssistantType,
hass: HomeAssistant,
platforms: Iterable[EntityPlatform],
func: str | Callable[..., Any],
call: ha.ServiceCall,
call: ServiceCall,
required_features: Iterable[int] | None = None,
) -> None:
"""Handle an entity service call.
@ -536,7 +536,7 @@ async def entity_service_call(
# If the service function is a string, we'll pass it the service call data
if isinstance(func, str):
data: dict | ha.ServiceCall = {
data: dict | ServiceCall = {
key: val
for key, val in call.data.items()
if key not in cv.ENTITY_SERVICE_FIELDS
@ -662,11 +662,11 @@ async def entity_service_call(
async def _handle_entity_call(
hass: HomeAssistantType,
hass: HomeAssistant,
entity: Entity,
func: str | Callable[..., Any],
data: dict | ha.ServiceCall,
context: ha.Context,
data: dict | ServiceCall,
context: Context,
) -> None:
"""Handle calling service method."""
entity.async_set_context(context)
@ -690,18 +690,18 @@ async def _handle_entity_call(
@bind_hass
@ha.callback
@callback
def async_register_admin_service(
hass: HomeAssistantType,
hass: HomeAssistant,
domain: str,
service: str,
service_func: Callable[[ha.ServiceCall], Awaitable | None],
service_func: Callable[[ServiceCall], Awaitable | None],
schema: vol.Schema = vol.Schema({}, extra=vol.PREVENT_EXTRA),
) -> None:
"""Register a service that requires admin access."""
@wraps(service_func)
async def admin_handler(call: ha.ServiceCall) -> None:
async def admin_handler(call: ServiceCall) -> None:
if call.context.user_id:
user = await hass.auth.async_get_user(call.context.user_id)
if user is None:
@ -717,20 +717,20 @@ def async_register_admin_service(
@bind_hass
@ha.callback
@callback
def verify_domain_control(
hass: HomeAssistantType, domain: str
) -> Callable[[Callable[[ha.ServiceCall], Any]], Callable[[ha.ServiceCall], Any]]:
hass: HomeAssistant, domain: str
) -> Callable[[Callable[[ServiceCall], Any]], Callable[[ServiceCall], Any]]:
"""Ensure permission to access any entity under domain in service call."""
def decorator(
service_handler: Callable[[ha.ServiceCall], Any]
) -> Callable[[ha.ServiceCall], Any]:
service_handler: Callable[[ServiceCall], Any]
) -> Callable[[ServiceCall], Any]:
"""Decorate."""
if not asyncio.iscoroutinefunction(service_handler):
raise HomeAssistantError("Can only decorate async functions.")
async def check_permissions(call: ha.ServiceCall) -> Any:
async def check_permissions(call: ServiceCall) -> Any:
"""Check user permission and raise before call if unauthorized."""
if not call.context.user_id:
return await service_handler(call)

View File

@ -20,12 +20,11 @@ from homeassistant.const import (
STATE_UNKNOWN,
STATE_UNLOCKED,
)
from homeassistant.core import Context, State
from homeassistant.core import Context, HomeAssistant, State
from homeassistant.loader import IntegrationNotFound, async_get_integration, bind_hass
import homeassistant.util.dt as dt_util
from .frame import report
from .typing import HomeAssistantType
_LOGGER = logging.getLogger(__name__)
@ -43,7 +42,7 @@ class AsyncTrackStates:
Warning added via `get_changed_since`.
"""
def __init__(self, hass: HomeAssistantType) -> None:
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize a TrackStates block."""
self.hass = hass
self.states: list[State] = []
@ -77,7 +76,7 @@ def get_changed_since(
@bind_hass
async def async_reproduce_state(
hass: HomeAssistantType,
hass: HomeAssistant,
states: State | Iterable[State],
*,
context: Context | None = None,

View File

@ -5,12 +5,10 @@ import datetime
from typing import TYPE_CHECKING
from homeassistant.const import SUN_EVENT_SUNRISE, SUN_EVENT_SUNSET
from homeassistant.core import callback
from homeassistant.core import HomeAssistant, callback
from homeassistant.loader import bind_hass
from homeassistant.util import dt as dt_util
from .typing import HomeAssistantType
if TYPE_CHECKING:
import astral
@ -19,7 +17,7 @@ DATA_LOCATION_CACHE = "astral_location_cache"
@callback
@bind_hass
def get_astral_location(hass: HomeAssistantType) -> astral.Location:
def get_astral_location(hass: HomeAssistant) -> astral.Location:
"""Get an astral location for the current Home Assistant configuration."""
from astral import Location # pylint: disable=import-outside-toplevel
@ -42,7 +40,7 @@ def get_astral_location(hass: HomeAssistantType) -> astral.Location:
@callback
@bind_hass
def get_astral_event_next(
hass: HomeAssistantType,
hass: HomeAssistant,
event: str,
utc_point_in_time: datetime.datetime | None = None,
offset: datetime.timedelta | None = None,
@ -89,7 +87,7 @@ def get_location_astral_event_next(
@callback
@bind_hass
def get_astral_event_date(
hass: HomeAssistantType,
hass: HomeAssistant,
event: str,
date: datetime.date | datetime.datetime | None = None,
) -> datetime.datetime | None:
@ -114,7 +112,7 @@ def get_astral_event_date(
@callback
@bind_hass
def is_up(
hass: HomeAssistantType, utc_point_in_time: datetime.datetime | None = None
hass: HomeAssistant, utc_point_in_time: datetime.datetime | None = None
) -> bool:
"""Calculate if the sun is currently up."""
if utc_point_in_time is None:

View File

@ -6,14 +6,13 @@ import platform
from typing import Any
from homeassistant.const import __version__ as current_version
from homeassistant.core import HomeAssistant
from homeassistant.loader import bind_hass
from homeassistant.util.package import is_virtual_env
from .typing import HomeAssistantType
@bind_hass
async def async_get_system_info(hass: HomeAssistantType) -> dict[str, Any]:
async def async_get_system_info(hass: HomeAssistant) -> dict[str, Any]:
"""Return info about the system."""
info_object = {
"installation_type": "Unknown",

View File

@ -32,10 +32,16 @@ from homeassistant.const import (
LENGTH_METERS,
STATE_UNKNOWN,
)
from homeassistant.core import State, callback, split_entity_id, valid_entity_id
from homeassistant.core import (
HomeAssistant,
State,
callback,
split_entity_id,
valid_entity_id,
)
from homeassistant.exceptions import TemplateError
from homeassistant.helpers import entity_registry, location as loc_helper
from homeassistant.helpers.typing import HomeAssistantType, TemplateVarsType
from homeassistant.helpers.typing import TemplateVarsType
from homeassistant.loader import bind_hass
from homeassistant.util import convert, dt as dt_util, location as loc_util
from homeassistant.util.async_ import run_callback_threadsafe
@ -75,7 +81,7 @@ DOMAIN_STATES_RATE_LIMIT = timedelta(seconds=1)
@bind_hass
def attach(hass: HomeAssistantType, obj: Any) -> None:
def attach(hass: HomeAssistant, obj: Any) -> None:
"""Recursively attach hass to all template instances in list and dict."""
if isinstance(obj, list):
for child in obj:
@ -568,7 +574,7 @@ class Template:
class AllStates:
"""Class to expose all HA states as attributes."""
def __init__(self, hass: HomeAssistantType) -> None:
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize all states."""
self._hass = hass
@ -622,7 +628,7 @@ class AllStates:
class DomainStates:
"""Class to expose a specific HA domain as attributes."""
def __init__(self, hass: HomeAssistantType, domain: str) -> None:
def __init__(self, hass: HomeAssistant, domain: str) -> None:
"""Initialize the domain states."""
self._hass = hass
self._domain = domain
@ -667,9 +673,7 @@ class TemplateState(State):
# Inheritance is done so functions that check against State keep working
# pylint: disable=super-init-not-called
def __init__(
self, hass: HomeAssistantType, state: State, collect: bool = True
) -> None:
def __init__(self, hass: HomeAssistant, state: State, collect: bool = True) -> None:
"""Initialize template state."""
self._hass = hass
self._state = state
@ -767,33 +771,31 @@ class TemplateState(State):
return f"<template TemplateState({self._state.__repr__()})>"
def _collect_state(hass: HomeAssistantType, entity_id: str) -> None:
def _collect_state(hass: HomeAssistant, entity_id: str) -> None:
entity_collect = hass.data.get(_RENDER_INFO)
if entity_collect is not None:
entity_collect.entities.add(entity_id)
def _state_generator(hass: HomeAssistantType, domain: str | None) -> Generator:
def _state_generator(hass: HomeAssistant, domain: str | None) -> Generator:
"""State generator for a domain or all states."""
for state in sorted(hass.states.async_all(domain), key=attrgetter("entity_id")):
yield TemplateState(hass, state, collect=False)
def _get_state_if_valid(
hass: HomeAssistantType, entity_id: str
) -> TemplateState | None:
def _get_state_if_valid(hass: HomeAssistant, entity_id: str) -> TemplateState | None:
state = hass.states.get(entity_id)
if state is None and not valid_entity_id(entity_id):
raise TemplateError(f"Invalid entity ID '{entity_id}'") # type: ignore
return _get_template_state_from_state(hass, entity_id, state)
def _get_state(hass: HomeAssistantType, entity_id: str) -> TemplateState | None:
def _get_state(hass: HomeAssistant, entity_id: str) -> TemplateState | None:
return _get_template_state_from_state(hass, entity_id, hass.states.get(entity_id))
def _get_template_state_from_state(
hass: HomeAssistantType, entity_id: str, state: State | None
hass: HomeAssistant, entity_id: str, state: State | None
) -> TemplateState | None:
if state is None:
# Only need to collect if none, if not none collect first actual
@ -804,7 +806,7 @@ def _get_template_state_from_state(
def _resolve_state(
hass: HomeAssistantType, entity_id_or_state: Any
hass: HomeAssistant, entity_id_or_state: Any
) -> State | TemplateState | None:
"""Return state or entity_id if given."""
if isinstance(entity_id_or_state, State):
@ -832,7 +834,7 @@ def result_as_boolean(template_result: str | None) -> bool:
return False
def expand(hass: HomeAssistantType, *args: Any) -> Iterable[State]:
def expand(hass: HomeAssistant, *args: Any) -> Iterable[State]:
"""Expand out any groups into entity states."""
search = list(args)
found = {}
@ -864,7 +866,7 @@ def expand(hass: HomeAssistantType, *args: Any) -> Iterable[State]:
return sorted(found.values(), key=lambda a: a.entity_id)
def device_entities(hass: HomeAssistantType, device_id: str) -> Iterable[str]:
def device_entities(hass: HomeAssistant, device_id: str) -> Iterable[str]:
"""Get entity ids for entities tied to a device."""
entity_reg = entity_registry.async_get(hass)
entries = entity_registry.async_entries_for_device(entity_reg, device_id)
@ -998,7 +1000,7 @@ def distance(hass, *args):
)
def is_state(hass: HomeAssistantType, entity_id: str, state: State) -> bool:
def is_state(hass: HomeAssistant, entity_id: str, state: State) -> bool:
"""Test if a state is a specific value."""
state_obj = _get_state(hass, entity_id)
return state_obj is not None and state_obj.state == state

View File

@ -6,7 +6,7 @@ from collections import ChainMap
import logging
from typing import Any
from homeassistant.core import callback
from homeassistant.core import HomeAssistant, callback
from homeassistant.loader import (
MAX_LOAD_CONCURRENTLY,
Integration,
@ -17,8 +17,6 @@ from homeassistant.loader import (
from homeassistant.util.async_ import gather_with_concurrency
from homeassistant.util.json import load_json
from .typing import HomeAssistantType
_LOGGER = logging.getLogger(__name__)
TRANSLATION_LOAD_LOCK = "translation_load_lock"
@ -148,7 +146,7 @@ def _build_resources(
async def async_get_component_strings(
hass: HomeAssistantType, language: str, components: set[str]
hass: HomeAssistant, language: str, components: set[str]
) -> dict[str, Any]:
"""Load translations."""
domains = list({loaded.split(".")[-1] for loaded in components})
@ -204,7 +202,7 @@ async def async_get_component_strings(
class _TranslationCache:
"""Cache for flattened translations."""
def __init__(self, hass: HomeAssistantType) -> None:
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the cache."""
self.hass = hass
self.loaded: dict[str, set[str]] = {}
@ -282,7 +280,7 @@ class _TranslationCache:
@bind_hass
async def async_get_translations(
hass: HomeAssistantType,
hass: HomeAssistant,
language: str,
category: str,
integration: str | None = None,

View File

@ -9,9 +9,9 @@ from typing import Any, Callable
import voluptuous as vol
from homeassistant.const import CONF_PLATFORM
from homeassistant.core import CALLBACK_TYPE, callback
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.typing import ConfigType, HomeAssistantType
from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import IntegrationNotFound, async_get_integration
_PLATFORM_ALIASES = {
@ -20,9 +20,7 @@ _PLATFORM_ALIASES = {
}
async def _async_get_trigger_platform(
hass: HomeAssistantType, config: ConfigType
) -> Any:
async def _async_get_trigger_platform(hass: HomeAssistant, config: ConfigType) -> Any:
platform = config[CONF_PLATFORM]
for alias, triggers in _PLATFORM_ALIASES.items():
if platform in triggers:
@ -41,7 +39,7 @@ async def _async_get_trigger_platform(
async def async_validate_trigger_config(
hass: HomeAssistantType, trigger_config: list[ConfigType]
hass: HomeAssistant, trigger_config: list[ConfigType]
) -> list[ConfigType]:
"""Validate triggers."""
config = []
@ -56,7 +54,7 @@ async def async_validate_trigger_config(
async def async_initialize_triggers(
hass: HomeAssistantType,
hass: HomeAssistant,
trigger_config: list[ConfigType],
action: Callable,
domain: str,