Add typing to homeassistant/*.py and homeassistant/util/ (#15569)

* Add typing to homeassistant/*.py and homeassistant/util/

* Fix wrong merge

* Restore iterable in OrderedSet

* Fix tests
This commit is contained in:
Andrey 2018-07-23 11:24:39 +03:00 committed by Paulus Schoutsen
parent b7c336a687
commit 140a874917
27 changed files with 532 additions and 384 deletions

View File

@ -20,7 +20,7 @@ from homeassistant.const import (
)
def attempt_use_uvloop():
def attempt_use_uvloop() -> None:
"""Attempt to use uvloop."""
import asyncio
@ -280,8 +280,8 @@ def setup_and_run_hass(config_dir: str,
# Imported here to avoid importing asyncio before monkey patch
from homeassistant.util.async_ import run_callback_threadsafe
def open_browser(event):
"""Open the webinterface in a browser."""
def open_browser(_: Any) -> None:
"""Open the web interface in a browser."""
if hass.config.api is not None: # type: ignore
import webbrowser
webbrowser.open(hass.config.api.base_url) # type: ignore

View File

@ -221,8 +221,8 @@ async def async_from_config_file(config_path: str,
@core.callback
def async_enable_logging(hass: core.HomeAssistant,
verbose: bool = False,
log_rotate_days=None,
log_file=None,
log_rotate_days: Optional[int] = None,
log_file: Optional[str] = None,
log_no_color: bool = False) -> None:
"""Set up the logging.
@ -291,7 +291,7 @@ def async_enable_logging(hass: core.HomeAssistant,
async_handler = AsyncHandler(hass.loop, err_handler)
async def async_stop_async_handler(event):
async def async_stop_async_handler(_: Any) -> None:
"""Cleanup async handler."""
logging.getLogger('').removeHandler(async_handler) # type: ignore
await async_handler.async_close(blocking=True)

View File

@ -9,7 +9,7 @@ import logging
from aiohttp import web
import voluptuous as vol
from typing import Optional
from homeassistant.auth.util import generate_secret
from homeassistant.components.http import HomeAssistantView
from homeassistant.const import CONF_API_KEY, EVENT_HOMEASSISTANT_STOP, URL_API
@ -241,7 +241,7 @@ class RachioIro:
# Only enabled zones
return [z for z in self._zones if z[KEY_ENABLED]]
def get_zone(self, zone_id) -> dict or None:
def get_zone(self, zone_id) -> Optional[dict]:
"""Return the zone with the given ID."""
for zone in self.list_zones(include_disabled=True):
if zone[KEY_ID] == zone_id:

View File

@ -7,8 +7,9 @@ import os
import re
import shutil
# pylint: disable=unused-import
from typing import Any, Tuple, Optional # noqa: F401
from typing import ( # noqa: F401
Any, Tuple, Optional, Dict, List, Union, Callable)
from types import ModuleType
import voluptuous as vol
from voluptuous.humanize import humanize_error
@ -21,7 +22,7 @@ from homeassistant.const import (
CONF_UNIT_SYSTEM_IMPERIAL, CONF_TEMPERATURE_UNIT, TEMP_CELSIUS,
__version__, CONF_CUSTOMIZE, CONF_CUSTOMIZE_DOMAIN, CONF_CUSTOMIZE_GLOB,
CONF_WHITELIST_EXTERNAL_DIRS, CONF_AUTH_PROVIDERS, CONF_TYPE)
from homeassistant.core import callback, DOMAIN as CONF_CORE
from homeassistant.core import callback, DOMAIN as CONF_CORE, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.loader import get_component, get_platform
from homeassistant.util.yaml import load_yaml, SECRET_YAML
@ -193,7 +194,7 @@ def ensure_config_exists(config_dir: str, detect_location: bool = True)\
return config_path
def create_default_config(config_dir: str, detect_location=True)\
def create_default_config(config_dir: str, detect_location: bool = True)\
-> Optional[str]:
"""Create a default configuration file in given configuration directory.
@ -276,7 +277,7 @@ def create_default_config(config_dir: str, detect_location=True)\
return None
async def async_hass_config_yaml(hass):
async def async_hass_config_yaml(hass: HomeAssistant) -> Dict:
"""Load YAML from a Home Assistant configuration file.
This function allow a component inside the asyncio loop to reload its
@ -284,23 +285,26 @@ async def async_hass_config_yaml(hass):
This method is a coroutine.
"""
def _load_hass_yaml_config():
def _load_hass_yaml_config() -> Dict:
path = find_config_file(hass.config.config_dir)
conf = load_yaml_config_file(path)
return conf
if path is None:
raise HomeAssistantError(
"Config file not found in: {}".format(hass.config.config_dir))
return load_yaml_config_file(path)
conf = await hass.async_add_job(_load_hass_yaml_config)
return conf
return await hass.async_add_executor_job(_load_hass_yaml_config)
def find_config_file(config_dir: str) -> Optional[str]:
def find_config_file(config_dir: Optional[str]) -> Optional[str]:
"""Look in given directory for supported configuration files."""
if config_dir is None:
return None
config_path = os.path.join(config_dir, YAML_CONFIG_FILE)
return config_path if os.path.isfile(config_path) else None
def load_yaml_config_file(config_path):
def load_yaml_config_file(config_path: str) -> Dict[Any, Any]:
"""Parse a YAML configuration file.
This method needs to run in an executor.
@ -323,7 +327,7 @@ def load_yaml_config_file(config_path):
return conf_dict
def process_ha_config_upgrade(hass):
def process_ha_config_upgrade(hass: HomeAssistant) -> None:
"""Upgrade configuration if necessary.
This method needs to run in an executor.
@ -360,7 +364,8 @@ def process_ha_config_upgrade(hass):
@callback
def async_log_exception(ex, domain, config, hass):
def async_log_exception(ex: vol.Invalid, domain: str, config: Dict,
hass: HomeAssistant) -> None:
"""Log an error for configuration validation.
This method must be run in the event loop.
@ -371,7 +376,7 @@ def async_log_exception(ex, domain, config, hass):
@callback
def _format_config_error(ex, domain, config):
def _format_config_error(ex: vol.Invalid, domain: str, config: Dict) -> str:
"""Generate log exception for configuration validation.
This method must be run in the event loop.
@ -396,7 +401,8 @@ def _format_config_error(ex, domain, config):
return message
async def async_process_ha_core_config(hass, config):
async def async_process_ha_core_config(
hass: HomeAssistant, config: Dict) -> None:
"""Process the [homeassistant] section from the configuration.
This method is a coroutine.
@ -405,12 +411,12 @@ async def async_process_ha_core_config(hass, config):
# Only load auth during startup.
if not hasattr(hass, 'auth'):
hass.auth = await auth.auth_manager_from_config(
hass, config.get(CONF_AUTH_PROVIDERS, []))
setattr(hass, 'auth', await auth.auth_manager_from_config(
hass, config.get(CONF_AUTH_PROVIDERS, [])))
hac = hass.config
def set_time_zone(time_zone_str):
def set_time_zone(time_zone_str: Optional[str]) -> None:
"""Help to set the time zone."""
if time_zone_str is None:
return
@ -430,11 +436,10 @@ async def async_process_ha_core_config(hass, config):
if key in config:
setattr(hac, attr, config[key])
if CONF_TIME_ZONE in config:
set_time_zone(config.get(CONF_TIME_ZONE))
set_time_zone(config.get(CONF_TIME_ZONE))
# Init whitelist external dir
hac.whitelist_external_dirs = set((hass.config.path('www'),))
hac.whitelist_external_dirs = {hass.config.path('www')}
if CONF_WHITELIST_EXTERNAL_DIRS in config:
hac.whitelist_external_dirs.update(
set(config[CONF_WHITELIST_EXTERNAL_DIRS]))
@ -484,12 +489,12 @@ async def async_process_ha_core_config(hass, config):
hac.time_zone, hac.elevation):
return
discovered = []
discovered = [] # type: List[Tuple[str, Any]]
# If we miss some of the needed values, auto detect them
if None in (hac.latitude, hac.longitude, hac.units,
hac.time_zone):
info = await hass.async_add_job(
info = await hass.async_add_executor_job(
loc_util.detect_location_info)
if info is None:
@ -515,7 +520,7 @@ async def async_process_ha_core_config(hass, config):
if hac.elevation is None and hac.latitude is not None and \
hac.longitude is not None:
elevation = await hass.async_add_job(
elevation = await hass.async_add_executor_job(
loc_util.elevation, hac.latitude, hac.longitude)
hac.elevation = elevation
discovered.append(('elevation', elevation))
@ -526,7 +531,8 @@ async def async_process_ha_core_config(hass, config):
", ".join('{}: {}'.format(key, val) for key, val in discovered))
def _log_pkg_error(package, component, config, message):
def _log_pkg_error(
package: str, component: str, config: Dict, message: str) -> None:
"""Log an error while merging packages."""
message = "Package {} setup failed. Component {} {}".format(
package, component, message)
@ -539,12 +545,13 @@ def _log_pkg_error(package, component, config, message):
_LOGGER.error(message)
def _identify_config_schema(module):
def _identify_config_schema(module: ModuleType) -> \
Tuple[Optional[str], Optional[Dict]]:
"""Extract the schema and identify list or dict based."""
try:
schema = module.CONFIG_SCHEMA.schema[module.DOMAIN]
schema = module.CONFIG_SCHEMA.schema[module.DOMAIN] # type: ignore
except (AttributeError, KeyError):
return (None, None)
return None, None
t_schema = str(schema)
if t_schema.startswith('{'):
return ('dict', schema)
@ -553,9 +560,10 @@ def _identify_config_schema(module):
return '', schema
def _recursive_merge(conf, package):
def _recursive_merge(
conf: Dict[str, Any], package: Dict[str, Any]) -> Union[bool, str]:
"""Merge package into conf, recursively."""
error = False
error = False # type: Union[bool, str]
for key, pack_conf in package.items():
if isinstance(pack_conf, dict):
if not pack_conf:
@ -576,8 +584,8 @@ def _recursive_merge(conf, package):
return error
def merge_packages_config(hass, config, packages,
_log_pkg_error=_log_pkg_error):
def merge_packages_config(hass: HomeAssistant, config: Dict, packages: Dict,
_log_pkg_error: Callable = _log_pkg_error) -> Dict:
"""Merge packages into the top-level configuration. Mutate config."""
# pylint: disable=too-many-nested-blocks
PACKAGES_CONFIG_SCHEMA(packages)
@ -641,7 +649,8 @@ def merge_packages_config(hass, config, packages,
@callback
def async_process_component_config(hass, config, domain):
def async_process_component_config(
hass: HomeAssistant, config: Dict, domain: str) -> Optional[Dict]:
"""Check component configuration and return processed configuration.
Returns None on error.
@ -703,14 +712,14 @@ def async_process_component_config(hass, config, domain):
return config
async def async_check_ha_config_file(hass):
async def async_check_ha_config_file(hass: HomeAssistant) -> Optional[str]:
"""Check if Home Assistant configuration file is valid.
This method is a coroutine.
"""
from homeassistant.scripts.check_config import check_ha_config_file
res = await hass.async_add_job(
res = await hass.async_add_executor_job(
check_ha_config_file, hass)
if not res.errors:
@ -719,7 +728,9 @@ async def async_check_ha_config_file(hass):
@callback
def async_notify_setup_error(hass, component, display_link=False):
def async_notify_setup_error(
hass: HomeAssistant, component: str,
display_link: bool = False) -> None:
"""Print a persistent notification.
This method must be run in the event loop.

View File

@ -113,10 +113,10 @@ the flow from the config panel.
import logging
import uuid
from typing import Set # noqa pylint: disable=unused-import
from typing import Set, Optional # noqa pylint: disable=unused-import
from homeassistant import data_entry_flow
from homeassistant.core import callback
from homeassistant.core import callback, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.setup import async_setup_component, async_process_deps_reqs
from homeassistant.util.decorator import Registry
@ -164,8 +164,9 @@ class ConfigEntry:
__slots__ = ('entry_id', 'version', 'domain', 'title', 'data', 'source',
'state')
def __init__(self, version, domain, title, data, source, entry_id=None,
state=ENTRY_STATE_NOT_LOADED):
def __init__(self, version: str, domain: str, title: str, data: dict,
source: str, entry_id: Optional[str] = None,
state: str = ENTRY_STATE_NOT_LOADED) -> None:
"""Initialize a config entry."""
# Unique id of the config entry
self.entry_id = entry_id or uuid.uuid4().hex
@ -188,7 +189,8 @@ class ConfigEntry:
# State of the entry (LOADED, NOT_LOADED)
self.state = state
async def async_setup(self, hass, *, component=None):
async def async_setup(
self, hass: HomeAssistant, *, component=None) -> None:
"""Set up an entry."""
if component is None:
component = getattr(hass.components, self.domain)

View File

@ -4,9 +4,9 @@ Core components of Home Assistant.
Home Assistant is a Home Automation framework for observing the state
of entities and react to changes.
"""
# pylint: disable=unused-import
import asyncio
from concurrent.futures import ThreadPoolExecutor
import datetime
import enum
import logging
import os
@ -17,9 +17,10 @@ import threading
from time import monotonic
from types import MappingProxyType
# pylint: disable=unused-import
from typing import ( # NOQA
Optional, Any, Callable, List, TypeVar, Dict, Coroutine, Set,
TYPE_CHECKING)
TYPE_CHECKING, Awaitable, Iterator)
from async_timeout import timeout
import voluptuous as vol
@ -44,11 +45,13 @@ from homeassistant.util import location
from homeassistant.util.unit_system import UnitSystem, METRIC_SYSTEM # NOQA
# Typing imports that create a circular dependency
# pylint: disable=using-constant-test,unused-import
# pylint: disable=using-constant-test
if TYPE_CHECKING:
from homeassistant.config_entries import ConfigEntries # noqa
from homeassistant.config_entries import ConfigEntries # noqa
T = TypeVar('T')
CALLABLE_T = TypeVar('CALLABLE_T', bound=Callable)
CALLBACK_TYPE = Callable[[], None]
DOMAIN = 'homeassistant'
@ -79,7 +82,7 @@ def valid_state(state: str) -> bool:
return len(state) < 256
def callback(func: Callable[..., T]) -> Callable[..., T]:
def callback(func: CALLABLE_T) -> CALLABLE_T:
"""Annotation to mark method as safe to call from within the event loop."""
setattr(func, '_hass_callback', True)
return func
@ -91,7 +94,7 @@ def is_callback(func: Callable[..., Any]) -> bool:
@callback
def async_loop_exception_handler(loop, context):
def async_loop_exception_handler(_: Any, context: Dict) -> None:
"""Handle all exception inside the core loop."""
kwargs = {}
exception = context.get('exception')
@ -119,7 +122,9 @@ class CoreState(enum.Enum):
class HomeAssistant:
"""Root object of the Home Assistant home automation."""
def __init__(self, loop=None):
def __init__(
self,
loop: Optional[asyncio.events.AbstractEventLoop] = None) -> None:
"""Initialize new Home Assistant object."""
if sys.platform == 'win32':
self.loop = loop or asyncio.ProactorEventLoop()
@ -170,7 +175,7 @@ class HomeAssistant:
self.loop.close()
return self.exit_code
async def async_start(self):
async def async_start(self) -> None:
"""Finalize startup from inside the event loop.
This method is a coroutine.
@ -178,8 +183,7 @@ class HomeAssistant:
_LOGGER.info("Starting Home Assistant")
self.state = CoreState.starting
# pylint: disable=protected-access
self.loop._thread_ident = threading.get_ident()
setattr(self.loop, '_thread_ident', threading.get_ident())
self.bus.async_fire(EVENT_HOMEASSISTANT_START)
try:
@ -230,7 +234,8 @@ class HomeAssistant:
elif asyncio.iscoroutinefunction(target):
task = self.loop.create_task(target(*args))
else:
task = self.loop.run_in_executor(None, target, *args)
task = self.loop.run_in_executor( # type: ignore
None, target, *args)
# If a task is scheduled
if self._track_task and task is not None:
@ -256,11 +261,11 @@ class HomeAssistant:
@callback
def async_add_executor_job(
self,
target: Callable[..., Any],
*args: Any) -> asyncio.Future:
target: Callable[..., T],
*args: Any) -> Awaitable[T]:
"""Add an executor job from within the event loop."""
task = self.loop.run_in_executor( # type: ignore
None, target, *args) # type: asyncio.Future
task = self.loop.run_in_executor(
None, target, *args)
# If a task is scheduled
if self._track_task:
@ -269,12 +274,12 @@ class HomeAssistant:
return task
@callback
def async_track_tasks(self):
def async_track_tasks(self) -> None:
"""Track tasks so you can wait for all tasks to be done."""
self._track_task = True
@callback
def async_stop_track_tasks(self):
def async_stop_track_tasks(self) -> None:
"""Stop track tasks so you can't wait for all tasks to be done."""
self._track_task = False
@ -297,7 +302,7 @@ class HomeAssistant:
run_coroutine_threadsafe(
self.async_block_till_done(), loop=self.loop).result()
async def async_block_till_done(self):
async def async_block_till_done(self) -> None:
"""Block till all pending work is done."""
# To flush out any call_soon_threadsafe
await asyncio.sleep(0, loop=self.loop)
@ -342,9 +347,9 @@ class EventOrigin(enum.Enum):
local = 'LOCAL'
remote = 'REMOTE'
def __str__(self):
def __str__(self) -> str:
"""Return the event."""
return self.value
return self.value # type: ignore
class Event:
@ -352,15 +357,16 @@ class Event:
__slots__ = ['event_type', 'data', 'origin', 'time_fired']
def __init__(self, event_type, data=None, origin=EventOrigin.local,
time_fired=None):
def __init__(self, event_type: str, data: Optional[Dict] = None,
origin: EventOrigin = EventOrigin.local,
time_fired: Optional[int] = None) -> None:
"""Initialize a new event."""
self.event_type = event_type
self.data = data or {}
self.origin = origin
self.time_fired = time_fired or dt_util.utcnow()
def as_dict(self):
def as_dict(self) -> Dict:
"""Create a dict representation of this Event.
Async friendly.
@ -372,7 +378,7 @@ class Event:
'time_fired': self.time_fired,
}
def __repr__(self):
def __repr__(self) -> str:
"""Return the representation."""
# pylint: disable=maybe-no-member
if self.data:
@ -383,9 +389,9 @@ class Event:
return "<Event {}[{}]>".format(self.event_type,
str(self.origin)[0])
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
"""Return the comparison."""
return (self.__class__ == other.__class__ and
return (self.__class__ == other.__class__ and # type: ignore
self.event_type == other.event_type and
self.data == other.data and
self.origin == other.origin and
@ -401,7 +407,7 @@ class EventBus:
self._hass = hass
@callback
def async_listeners(self):
def async_listeners(self) -> Dict[str, int]:
"""Return dictionary with events and the number of listeners.
This method must be run in the event loop.
@ -410,20 +416,21 @@ class EventBus:
for key in self._listeners}
@property
def listeners(self):
def listeners(self) -> Dict[str, int]:
"""Return dictionary with events and the number of listeners."""
return run_callback_threadsafe(
return run_callback_threadsafe( # type: ignore
self._hass.loop, self.async_listeners
).result()
def fire(self, event_type: str, event_data=None, origin=EventOrigin.local):
def fire(self, event_type: str, event_data: Optional[Dict] = None,
origin: EventOrigin = EventOrigin.local) -> None:
"""Fire an event."""
self._hass.loop.call_soon_threadsafe(
self.async_fire, event_type, event_data, origin)
@callback
def async_fire(self, event_type: str, event_data=None,
origin=EventOrigin.local):
def async_fire(self, event_type: str, event_data: Optional[Dict] = None,
origin: EventOrigin = EventOrigin.local) -> None:
"""Fire an event.
This method must be run in the event loop.
@ -447,7 +454,8 @@ class EventBus:
for func in listeners:
self._hass.async_add_job(func, event)
def listen(self, event_type, listener):
def listen(
self, event_type: str, listener: Callable) -> CALLBACK_TYPE:
"""Listen for all events or events of a specific type.
To listen to all events specify the constant ``MATCH_ALL``
@ -456,7 +464,7 @@ class EventBus:
async_remove_listener = run_callback_threadsafe(
self._hass.loop, self.async_listen, event_type, listener).result()
def remove_listener():
def remove_listener() -> None:
"""Remove the listener."""
run_callback_threadsafe(
self._hass.loop, async_remove_listener).result()
@ -464,7 +472,8 @@ class EventBus:
return remove_listener
@callback
def async_listen(self, event_type, listener):
def async_listen(
self, event_type: str, listener: Callable) -> CALLBACK_TYPE:
"""Listen for all events or events of a specific type.
To listen to all events specify the constant ``MATCH_ALL``
@ -477,13 +486,14 @@ class EventBus:
else:
self._listeners[event_type] = [listener]
def remove_listener():
def remove_listener() -> None:
"""Remove the listener."""
self._async_remove_listener(event_type, listener)
return remove_listener
def listen_once(self, event_type, listener):
def listen_once(
self, event_type: str, listener: Callable) -> CALLBACK_TYPE:
"""Listen once for event of a specific type.
To listen to all events specify the constant ``MATCH_ALL``
@ -495,7 +505,7 @@ class EventBus:
self._hass.loop, self.async_listen_once, event_type, listener,
).result()
def remove_listener():
def remove_listener() -> None:
"""Remove the listener."""
run_callback_threadsafe(
self._hass.loop, async_remove_listener).result()
@ -503,7 +513,8 @@ class EventBus:
return remove_listener
@callback
def async_listen_once(self, event_type, listener):
def async_listen_once(
self, event_type: str, listener: Callable) -> CALLBACK_TYPE:
"""Listen once for event of a specific type.
To listen to all events specify the constant ``MATCH_ALL``
@ -514,8 +525,8 @@ class EventBus:
This method must be run in the event loop.
"""
@callback
def onetime_listener(event):
"""Remove listener from eventbus and then fire listener."""
def onetime_listener(event: Event) -> None:
"""Remove listener from event bus and then fire listener."""
if hasattr(onetime_listener, 'run'):
return
# Set variable so that we will never run twice.
@ -530,7 +541,8 @@ class EventBus:
return self.async_listen(event_type, onetime_listener)
@callback
def _async_remove_listener(self, event_type, listener):
def _async_remove_listener(
self, event_type: str, listener: Callable) -> None:
"""Remove a listener of a specific event_type.
This method must be run in the event loop.
@ -560,8 +572,10 @@ class State:
__slots__ = ['entity_id', 'state', 'attributes',
'last_changed', 'last_updated']
def __init__(self, entity_id, state, attributes=None, last_changed=None,
last_updated=None):
def __init__(self, entity_id: str, state: Any,
attributes: Optional[Dict] = None,
last_changed: Optional[datetime.datetime] = None,
last_updated: Optional[datetime.datetime] = None) -> None:
"""Initialize a new state."""
state = str(state)
@ -582,23 +596,23 @@ class State:
self.last_changed = last_changed or self.last_updated
@property
def domain(self):
def domain(self) -> str:
"""Domain of this state."""
return split_entity_id(self.entity_id)[0]
@property
def object_id(self):
def object_id(self) -> str:
"""Object id of this state."""
return split_entity_id(self.entity_id)[1]
@property
def name(self):
def name(self) -> str:
"""Name of this state."""
return (
self.attributes.get(ATTR_FRIENDLY_NAME) or
self.object_id.replace('_', ' '))
def as_dict(self):
def as_dict(self) -> Dict:
"""Return a dict representation of the State.
Async friendly.
@ -613,7 +627,7 @@ class State:
'last_updated': self.last_updated}
@classmethod
def from_dict(cls, json_dict):
def from_dict(cls, json_dict: Dict) -> Any:
"""Initialize a state from a dict.
Async friendly.
@ -637,14 +651,14 @@ class State:
return cls(json_dict['entity_id'], json_dict['state'],
json_dict.get('attributes'), last_changed, last_updated)
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
"""Return the comparison of the state."""
return (self.__class__ == other.__class__ and
return (self.__class__ == other.__class__ and # type: ignore
self.entity_id == other.entity_id and
self.state == other.state and
self.attributes == other.attributes)
def __repr__(self):
def __repr__(self) -> str:
"""Return the representation of the states."""
attr = "; {}".format(util.repr_helper(self.attributes)) \
if self.attributes else ""
@ -657,21 +671,23 @@ class State:
class StateMachine:
"""Helper class that tracks the state of different entities."""
def __init__(self, bus, loop):
def __init__(self, bus: EventBus,
loop: asyncio.events.AbstractEventLoop) -> None:
"""Initialize state machine."""
self._states = {} # type: Dict[str, State]
self._bus = bus
self._loop = loop
def entity_ids(self, domain_filter=None):
def entity_ids(self, domain_filter: Optional[str] = None)-> List[str]:
"""List of entity ids that are being tracked."""
future = run_callback_threadsafe(
self._loop, self.async_entity_ids, domain_filter
)
return future.result()
return future.result() # type: ignore
@callback
def async_entity_ids(self, domain_filter=None):
def async_entity_ids(
self, domain_filter: Optional[str] = None) -> List[str]:
"""List of entity ids that are being tracked.
This method must be run in the event loop.
@ -684,26 +700,27 @@ class StateMachine:
return [state.entity_id for state in self._states.values()
if state.domain == domain_filter]
def all(self):
def all(self)-> List[State]:
"""Create a list of all states."""
return run_callback_threadsafe(self._loop, self.async_all).result()
return run_callback_threadsafe( # type: ignore
self._loop, self.async_all).result()
@callback
def async_all(self):
def async_all(self)-> List[State]:
"""Create a list of all states.
This method must be run in the event loop.
"""
return list(self._states.values())
def get(self, entity_id):
def get(self, entity_id: str) -> Optional[State]:
"""Retrieve state of entity_id or None if not found.
Async friendly.
"""
return self._states.get(entity_id.lower())
def is_state(self, entity_id, state):
def is_state(self, entity_id: str, state: State) -> bool:
"""Test if entity exists and is specified state.
Async friendly.
@ -711,16 +728,16 @@ class StateMachine:
state_obj = self.get(entity_id)
return state_obj is not None and state_obj.state == state
def remove(self, entity_id):
def remove(self, entity_id: str) -> bool:
"""Remove the state of an entity.
Returns boolean to indicate if an entity was removed.
"""
return run_callback_threadsafe(
return run_callback_threadsafe( # type: ignore
self._loop, self.async_remove, entity_id).result()
@callback
def async_remove(self, entity_id):
def async_remove(self, entity_id: str) -> bool:
"""Remove the state of an entity.
Returns boolean to indicate if an entity was removed.
@ -740,7 +757,9 @@ class StateMachine:
})
return True
def set(self, entity_id, new_state, attributes=None, force_update=False):
def set(self, entity_id: str, new_state: Any,
attributes: Optional[Dict] = None,
force_update: bool = False) -> None:
"""Set the state of an entity, add entity if it does not exist.
Attributes is an optional dict to specify attributes of this state.
@ -754,8 +773,9 @@ class StateMachine:
).result()
@callback
def async_set(self, entity_id, new_state, attributes=None,
force_update=False):
def async_set(self, entity_id: str, new_state: Any,
attributes: Optional[Dict] = None,
force_update: bool = False) -> None:
"""Set the state of an entity, add entity if it does not exist.
Attributes is an optional dict to specify attributes of this state.
@ -769,15 +789,19 @@ class StateMachine:
new_state = str(new_state)
attributes = attributes or {}
old_state = self._states.get(entity_id)
is_existing = old_state is not None
same_state = (is_existing and old_state.state == new_state and
not force_update)
same_attr = is_existing and old_state.attributes == attributes
if old_state is None:
same_state = False
same_attr = False
last_changed = None
else:
same_state = (old_state.state == new_state and
not force_update)
same_attr = old_state.attributes == attributes
last_changed = old_state.last_changed if same_state else None
if same_state and same_attr:
return
last_changed = old_state.last_changed if same_state else None
state = State(entity_id, new_state, attributes, last_changed)
self._states[entity_id] = state
self._bus.async_fire(EVENT_STATE_CHANGED, {
@ -792,7 +816,7 @@ class Service:
__slots__ = ['func', 'schema', 'is_callback', 'is_coroutinefunction']
def __init__(self, func, schema):
def __init__(self, func: Callable, schema: Optional[vol.Schema]) -> None:
"""Initialize a service."""
self.func = func
self.schema = schema
@ -805,14 +829,15 @@ class ServiceCall:
__slots__ = ['domain', 'service', 'data', 'call_id']
def __init__(self, domain, service, data=None, call_id=None):
def __init__(self, domain: str, service: str, data: Optional[Dict] = None,
call_id: Optional[str] = None) -> None:
"""Initialize a service call."""
self.domain = domain.lower()
self.service = service.lower()
self.data = MappingProxyType(data or {})
self.call_id = call_id
def __repr__(self):
def __repr__(self) -> str:
"""Return the representation of the service."""
if self.data:
return "<ServiceCall {}.{}: {}>".format(
@ -824,13 +849,13 @@ class ServiceCall:
class ServiceRegistry:
"""Offer the services over the eventbus."""
def __init__(self, hass):
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize a service registry."""
self._services = {} # type: Dict[str, Dict[str, Service]]
self._hass = hass
self._async_unsub_call_event = None
self._async_unsub_call_event = None # type: Optional[CALLBACK_TYPE]
def _gen_unique_id():
def _gen_unique_id() -> Iterator[str]:
cur_id = 1
while True:
yield '{}-{}'.format(id(self), cur_id)
@ -840,14 +865,14 @@ class ServiceRegistry:
self._generate_unique_id = lambda: next(gen)
@property
def services(self):
def services(self) -> Dict[str, Dict[str, Service]]:
"""Return dictionary with per domain a list of available services."""
return run_callback_threadsafe(
return run_callback_threadsafe( # type: ignore
self._hass.loop, self.async_services,
).result()
@callback
def async_services(self):
def async_services(self) -> Dict[str, Dict[str, Service]]:
"""Return dictionary with per domain a list of available services.
This method must be run in the event loop.
@ -855,14 +880,15 @@ class ServiceRegistry:
return {domain: self._services[domain].copy()
for domain in self._services}
def has_service(self, domain, service):
def has_service(self, domain: str, service: str) -> bool:
"""Test if specified service exists.
Async friendly.
"""
return service.lower() in self._services.get(domain.lower(), [])
def register(self, domain, service, service_func, schema=None):
def register(self, domain: str, service: str, service_func: Callable,
schema: Optional[vol.Schema] = None) -> None:
"""
Register a service.
@ -874,7 +900,8 @@ class ServiceRegistry:
).result()
@callback
def async_register(self, domain, service, service_func, schema=None):
def async_register(self, domain: str, service: str, service_func: Callable,
schema: Optional[vol.Schema] = None) -> None:
"""
Register a service.
@ -900,13 +927,13 @@ class ServiceRegistry:
{ATTR_DOMAIN: domain, ATTR_SERVICE: service}
)
def remove(self, domain, service):
def remove(self, domain: str, service: str) -> None:
"""Remove a registered service from service handler."""
run_callback_threadsafe(
self._hass.loop, self.async_remove, domain, service).result()
@callback
def async_remove(self, domain, service):
def async_remove(self, domain: str, service: str) -> None:
"""Remove a registered service from service handler.
This method must be run in the event loop.
@ -926,7 +953,9 @@ class ServiceRegistry:
{ATTR_DOMAIN: domain, ATTR_SERVICE: service}
)
def call(self, domain, service, service_data=None, blocking=False):
def call(self, domain: str, service: str,
service_data: Optional[Dict] = None,
blocking: bool = False) -> Optional[bool]:
"""
Call a service.
@ -943,13 +972,14 @@ class ServiceRegistry:
Because the service is sent as an event you are not allowed to use
the keys ATTR_DOMAIN and ATTR_SERVICE in your service_data.
"""
return run_coroutine_threadsafe(
return run_coroutine_threadsafe( # type: ignore
self.async_call(domain, service, service_data, blocking),
self._hass.loop
).result()
async def async_call(self, domain, service, service_data=None,
blocking=False):
async def async_call(self, domain: str, service: str,
service_data: Optional[Dict] = None,
blocking: bool = False) -> Optional[bool]:
"""
Call a service.
@ -981,7 +1011,7 @@ class ServiceRegistry:
fut = asyncio.Future(loop=self._hass.loop) # type: asyncio.Future
@callback
def service_executed(event):
def service_executed(event: Event) -> None:
"""Handle an executed service."""
if event.data[ATTR_SERVICE_CALL_ID] == call_id:
fut.set_result(True)
@ -989,20 +1019,22 @@ class ServiceRegistry:
unsub = self._hass.bus.async_listen(
EVENT_SERVICE_EXECUTED, service_executed)
self._hass.bus.async_fire(EVENT_CALL_SERVICE, event_data)
self._hass.bus.async_fire(EVENT_CALL_SERVICE, event_data)
if blocking:
done, _ = await asyncio.wait(
[fut], loop=self._hass.loop, timeout=SERVICE_CALL_LIMIT)
success = bool(done)
unsub()
return success
async def _event_to_service_call(self, event):
self._hass.bus.async_fire(EVENT_CALL_SERVICE, event_data)
return None
async def _event_to_service_call(self, event: Event) -> None:
"""Handle the SERVICE_CALLED events from the EventBus."""
service_data = event.data.get(ATTR_SERVICE_DATA) or {}
domain = event.data.get(ATTR_DOMAIN).lower()
service = event.data.get(ATTR_SERVICE).lower()
domain = event.data.get(ATTR_DOMAIN).lower() # type: ignore
service = event.data.get(ATTR_SERVICE).lower() # type: ignore
call_id = event.data.get(ATTR_SERVICE_CALL_ID)
if not self.has_service(domain, service):
@ -1013,7 +1045,7 @@ class ServiceRegistry:
service_handler = self._services[domain][service]
def fire_service_executed():
def fire_service_executed() -> None:
"""Fire service executed event."""
if not call_id:
return
@ -1045,12 +1077,12 @@ class ServiceRegistry:
await service_handler.func(service_call)
fire_service_executed()
else:
def execute_service():
def execute_service() -> None:
"""Execute a service and fires a SERVICE_EXECUTED event."""
service_handler.func(service_call)
fire_service_executed()
await self._hass.async_add_job(execute_service)
await self._hass.async_add_executor_job(execute_service)
except Exception: # pylint: disable=broad-except
_LOGGER.exception('Error executing service %s', service_call)
@ -1058,13 +1090,13 @@ class ServiceRegistry:
class Config:
"""Configuration settings for Home Assistant."""
def __init__(self):
def __init__(self) -> None:
"""Initialize a new config object."""
self.latitude = None # type: Optional[float]
self.longitude = None # type: Optional[float]
self.elevation = None # type: Optional[int]
self.location_name = None # type: Optional[str]
self.time_zone = None # type: Optional[str]
self.time_zone = None # type: Optional[datetime.tzinfo]
self.units = METRIC_SYSTEM # type: UnitSystem
# If True, pip install is skipped for requirements on startup
@ -1090,7 +1122,7 @@ class Config:
return self.units.length(
location.distance(self.latitude, self.longitude, lat, lon), 'm')
def path(self, *path):
def path(self, *path: str) -> str:
"""Generate path to the file within the configuration directory.
Async friendly.
@ -1122,12 +1154,14 @@ class Config:
return False
def as_dict(self):
def as_dict(self) -> Dict:
"""Create a dictionary representation of this dict.
Async friendly.
"""
time_zone = self.time_zone or dt_util.UTC
time_zone = dt_util.UTC.zone
if self.time_zone and getattr(self.time_zone, 'zone'):
time_zone = getattr(self.time_zone, 'zone')
return {
'latitude': self.latitude,
@ -1135,7 +1169,7 @@ class Config:
'elevation': self.elevation,
'unit_system': self.units.as_dict(),
'location_name': self.location_name,
'time_zone': time_zone.zone,
'time_zone': time_zone,
'components': self.components,
'config_dir': self.config_dir,
'whitelist_external_dirs': self.whitelist_external_dirs,
@ -1143,12 +1177,12 @@ class Config:
}
def _async_create_timer(hass):
def _async_create_timer(hass: HomeAssistant) -> None:
"""Create a timer that will start on HOMEASSISTANT_START."""
handle = None
@callback
def fire_time_event(nxt):
def fire_time_event(nxt: float) -> None:
"""Fire next time event."""
nonlocal handle
@ -1165,7 +1199,7 @@ def _async_create_timer(hass):
handle = hass.loop.call_later(slp_seconds, fire_time_event, nxt)
@callback
def stop_timer(event):
def stop_timer(_: Event) -> None:
"""Stop the timer."""
if handle is not None:
handle.cancel()

View File

@ -1,8 +1,9 @@
"""Classes to help gather user submissions."""
import logging
import uuid
from typing import Dict, Any # noqa pylint: disable=unused-import
from .core import callback
import voluptuous as vol
from typing import Dict, Any, Callable, List, Optional # noqa pylint: disable=unused-import
from .core import callback, HomeAssistant
from .exceptions import HomeAssistantError
_LOGGER = logging.getLogger(__name__)
@ -35,7 +36,8 @@ class UnknownStep(FlowError):
class FlowManager:
"""Manage all the flows that are in progress."""
def __init__(self, hass, async_create_flow, async_finish_flow):
def __init__(self, hass: HomeAssistant, async_create_flow: Callable,
async_finish_flow: Callable) -> None:
"""Initialize the flow manager."""
self.hass = hass
self._progress = {} # type: Dict[str, Any]
@ -43,7 +45,7 @@ class FlowManager:
self._async_finish_flow = async_finish_flow
@callback
def async_progress(self):
def async_progress(self) -> List[Dict]:
"""Return the flows in progress."""
return [{
'flow_id': flow.flow_id,
@ -51,7 +53,8 @@ class FlowManager:
'source': flow.source,
} for flow in self._progress.values()]
async def async_init(self, handler, *, source=SOURCE_USER, data=None):
async def async_init(self, handler: Callable, *, source: str = SOURCE_USER,
data: str = None) -> Any:
"""Start a configuration flow."""
flow = await self._async_create_flow(handler, source=source, data=data)
flow.hass = self.hass
@ -67,7 +70,8 @@ class FlowManager:
return await self._async_handle_step(flow, step, data)
async def async_configure(self, flow_id, user_input=None):
async def async_configure(
self, flow_id: str, user_input: str = None) -> Any:
"""Continue a configuration flow."""
flow = self._progress.get(flow_id)
@ -83,12 +87,13 @@ class FlowManager:
flow, step_id, user_input)
@callback
def async_abort(self, flow_id):
def async_abort(self, flow_id: str) -> None:
"""Abort a flow."""
if self._progress.pop(flow_id, None) is None:
raise UnknownFlow
async def _async_handle_step(self, flow, step_id, user_input):
async def _async_handle_step(self, flow: Any, step_id: str,
user_input: Optional[str]) -> Dict:
"""Handle a step of a flow."""
method = "async_step_{}".format(step_id)
@ -97,7 +102,7 @@ class FlowManager:
raise UnknownStep("Handler {} doesn't support step {}".format(
flow.__class__.__name__, step_id))
result = await getattr(flow, method)(user_input)
result = await getattr(flow, method)(user_input) # type: Dict
if result['type'] not in (RESULT_TYPE_FORM, RESULT_TYPE_CREATE_ENTRY,
RESULT_TYPE_ABORT):
@ -133,8 +138,9 @@ class FlowHandler:
VERSION = 1
@callback
def async_show_form(self, *, step_id, data_schema=None, errors=None,
description_placeholders=None):
def async_show_form(self, *, step_id: str, data_schema: vol.Schema = None,
errors: Dict = None,
description_placeholders: Dict = None) -> Dict:
"""Return the definition of a form to gather user input."""
return {
'type': RESULT_TYPE_FORM,
@ -147,7 +153,7 @@ class FlowHandler:
}
@callback
def async_create_entry(self, *, title, data):
def async_create_entry(self, *, title: str, data: Dict) -> Dict:
"""Finish config flow and create a config entry."""
return {
'version': self.VERSION,
@ -160,7 +166,7 @@ class FlowHandler:
}
@callback
def async_abort(self, *, reason):
def async_abort(self, *, reason: str) -> Dict:
"""Abort the config flow."""
return {
'type': RESULT_TYPE_ABORT,

View File

@ -17,7 +17,7 @@ import sys
from types import ModuleType
# pylint: disable=unused-import
from typing import Optional, Set, TYPE_CHECKING # NOQA
from typing import Optional, Set, TYPE_CHECKING, Callable, Any, TypeVar # NOQA
from homeassistant.const import PLATFORM_FORMAT
from homeassistant.util import OrderedSet
@ -27,6 +27,8 @@ from homeassistant.util import OrderedSet
if TYPE_CHECKING:
from homeassistant.core import HomeAssistant # NOQA
CALLABLE_T = TypeVar('CALLABLE_T', bound=Callable)
PREPARED = False
DEPENDENCY_BLACKLIST = {'config'}
@ -51,7 +53,8 @@ def set_component(hass, # type: HomeAssistant
cache[comp_name] = component
def get_platform(hass, domain: str, platform: str) -> Optional[ModuleType]:
def get_platform(hass, # type: HomeAssistant
domain: str, platform: str) -> Optional[ModuleType]:
"""Try to load specified platform.
Async friendly.
@ -59,7 +62,8 @@ def get_platform(hass, domain: str, platform: str) -> Optional[ModuleType]:
return get_component(hass, PLATFORM_FORMAT.format(domain, platform))
def get_component(hass, comp_or_platform) -> Optional[ModuleType]:
def get_component(hass, # type: HomeAssistant
comp_or_platform: str) -> Optional[ModuleType]:
"""Try to load specified component.
Looks in config dir first, then built-in components.
@ -73,6 +77,9 @@ def get_component(hass, comp_or_platform) -> Optional[ModuleType]:
cache = hass.data.get(DATA_KEY)
if cache is None:
if hass.config.config_dir is None:
_LOGGER.error("Can't load components - config dir is not set")
return None
# Only insert if it's not there (happens during tests)
if sys.path[0] != hass.config.config_dir:
sys.path.insert(0, hass.config.config_dir)
@ -134,14 +141,38 @@ def get_component(hass, comp_or_platform) -> Optional[ModuleType]:
return None
class ModuleWrapper:
"""Class to wrap a Python module and auto fill in hass argument."""
def __init__(self,
hass, # type: HomeAssistant
module: ModuleType) -> None:
"""Initialize the module wrapper."""
self._hass = hass
self._module = module
def __getattr__(self, attr: str) -> Any:
"""Fetch an attribute."""
value = getattr(self._module, attr)
if hasattr(value, '__bind_hass'):
value = ft.partial(value, self._hass)
setattr(self, attr, value)
return value
class Components:
"""Helper to load components."""
def __init__(self, hass):
def __init__(
self,
hass # type: HomeAssistant
) -> None:
"""Initialize the Components class."""
self._hass = hass
def __getattr__(self, comp_name):
def __getattr__(self, comp_name: str) -> ModuleWrapper:
"""Fetch a component."""
component = get_component(self._hass, comp_name)
if component is None:
@ -154,11 +185,14 @@ class Components:
class Helpers:
"""Helper to load helpers."""
def __init__(self, hass):
def __init__(
self,
hass # type: HomeAssistant
) -> None:
"""Initialize the Helpers class."""
self._hass = hass
def __getattr__(self, helper_name):
def __getattr__(self, helper_name: str) -> ModuleWrapper:
"""Fetch a helper."""
helper = importlib.import_module(
'homeassistant.helpers.{}'.format(helper_name))
@ -167,33 +201,14 @@ class Helpers:
return wrapped
class ModuleWrapper:
"""Class to wrap a Python module and auto fill in hass argument."""
def __init__(self, hass, module):
"""Initialize the module wrapper."""
self._hass = hass
self._module = module
def __getattr__(self, attr):
"""Fetch an attribute."""
value = getattr(self._module, attr)
if hasattr(value, '__bind_hass'):
value = ft.partial(value, self._hass)
setattr(self, attr, value)
return value
def bind_hass(func):
def bind_hass(func: CALLABLE_T) -> CALLABLE_T:
"""Decorate function to indicate that first argument is hass."""
# pylint: disable=protected-access
func.__bind_hass = True
setattr(func, '__bind_hass', True)
return func
def load_order_component(hass, comp_name: str) -> OrderedSet:
def load_order_component(hass, # type: HomeAssistant
comp_name: str) -> OrderedSet:
"""Return an OrderedSet of components in the correct order of loading.
Raises HomeAssistantError if a circular dependency is detected.
@ -204,7 +219,8 @@ def load_order_component(hass, comp_name: str) -> OrderedSet:
return _load_order_component(hass, comp_name, OrderedSet(), set())
def _load_order_component(hass, comp_name: str, load_order: OrderedSet,
def _load_order_component(hass, # type: HomeAssistant
comp_name: str, load_order: OrderedSet,
loading: Set) -> OrderedSet:
"""Recursive function to get load order of components.

View File

@ -20,9 +20,10 @@ Related Python bugs:
- https://bugs.python.org/issue26617
"""
import sys
from typing import Any
def patch_weakref_tasks():
def patch_weakref_tasks() -> None:
"""Replace weakref.WeakSet to address Python 3 bug."""
# pylint: disable=no-self-use, protected-access, bare-except
import asyncio.tasks
@ -30,7 +31,7 @@ def patch_weakref_tasks():
class IgnoreCalls:
"""Ignore add calls."""
def add(self, other):
def add(self, other: Any) -> None:
"""No-op add."""
return
@ -41,7 +42,7 @@ def patch_weakref_tasks():
pass
def disable_c_asyncio():
def disable_c_asyncio() -> None:
"""Disable using C implementation of asyncio.
Required to be able to apply the weakref monkey patch.
@ -53,12 +54,12 @@ def disable_c_asyncio():
PATH_TRIGGER = '_asyncio'
def __init__(self, path_entry):
def __init__(self, path_entry: str) -> None:
if path_entry != self.PATH_TRIGGER:
raise ImportError()
return
def find_module(self, fullname, path=None):
def find_module(self, fullname: str, path: Any = None) -> None:
"""Find a module."""
if fullname == self.PATH_TRIGGER:
# We lint in Py35, exception is introduced in Py36

View File

@ -13,7 +13,7 @@ import json
import logging
import urllib.parse
from typing import Optional
from typing import Optional, Dict, Any, List
from aiohttp.hdrs import METH_GET, METH_POST, METH_DELETE, CONTENT_TYPE
import requests
@ -62,7 +62,7 @@ class API:
if port is not None:
self.base_url += ':{}'.format(port)
self.status = None
self.status = None # type: Optional[APIStatus]
self._headers = {CONTENT_TYPE: CONTENT_TYPE_JSON}
if api_password is not None:
@ -75,20 +75,24 @@ class API:
return self.status == APIStatus.OK
def __call__(self, method, path, data=None, timeout=5):
def __call__(self, method: str, path: str, data: Dict = None,
timeout: int = 5) -> requests.Response:
"""Make a call to the Home Assistant API."""
if data is not None:
data = json.dumps(data, cls=JSONEncoder)
if data is None:
data_str = None
else:
data_str = json.dumps(data, cls=JSONEncoder)
url = urllib.parse.urljoin(self.base_url, path)
try:
if method == METH_GET:
return requests.get(
url, params=data, timeout=timeout, headers=self._headers)
url, params=data_str, timeout=timeout,
headers=self._headers)
return requests.request(
method, url, data=data, timeout=timeout,
method, url, data=data_str, timeout=timeout,
headers=self._headers)
except requests.exceptions.ConnectionError:
@ -110,7 +114,7 @@ class JSONEncoder(json.JSONEncoder):
"""JSONEncoder that supports Home Assistant objects."""
# pylint: disable=method-hidden
def default(self, o):
def default(self, o: Any) -> Any:
"""Convert Home Assistant objects.
Hand other objects to the original method.
@ -125,7 +129,7 @@ class JSONEncoder(json.JSONEncoder):
return json.JSONEncoder.default(self, o)
def validate_api(api):
def validate_api(api: API) -> APIStatus:
"""Make a call to validate API."""
try:
req = api(METH_GET, URL_API)
@ -142,12 +146,12 @@ def validate_api(api):
return APIStatus.CANNOT_CONNECT
def get_event_listeners(api):
def get_event_listeners(api: API) -> Dict:
"""List of events that is being listened for."""
try:
req = api(METH_GET, URL_API_EVENTS)
return req.json() if req.status_code == 200 else {}
return req.json() if req.status_code == 200 else {} # type: ignore
except (HomeAssistantError, ValueError):
# ValueError if req.json() can't parse the json
@ -156,7 +160,7 @@ def get_event_listeners(api):
return {}
def fire_event(api, event_type, data=None):
def fire_event(api: API, event_type: str, data: Dict = None) -> None:
"""Fire an event at remote API."""
try:
req = api(METH_POST, URL_API_EVENTS_EVENT.format(event_type), data)
@ -169,7 +173,7 @@ def fire_event(api, event_type, data=None):
_LOGGER.exception("Error firing event")
def get_state(api, entity_id):
def get_state(api: API, entity_id: str) -> Optional[ha.State]:
"""Query given API for state of entity_id."""
try:
req = api(METH_GET, URL_API_STATES_ENTITY.format(entity_id))
@ -186,7 +190,7 @@ def get_state(api, entity_id):
return None
def get_states(api):
def get_states(api: API) -> List[ha.State]:
"""Query given API for all states."""
try:
req = api(METH_GET,
@ -202,7 +206,7 @@ def get_states(api):
return []
def remove_state(api, entity_id):
def remove_state(api: API, entity_id: str) -> bool:
"""Call API to remove state for entity_id.
Return True if entity is gone (removed/never existed).
@ -222,7 +226,8 @@ def remove_state(api, entity_id):
return False
def set_state(api, entity_id, new_state, attributes=None, force_update=False):
def set_state(api: API, entity_id: str, new_state: str,
attributes: Dict = None, force_update: bool = False) -> bool:
"""Tell API to update state for entity_id.
Return True if success.
@ -249,14 +254,14 @@ def set_state(api, entity_id, new_state, attributes=None, force_update=False):
return False
def is_state(api, entity_id, state):
def is_state(api: API, entity_id: str, state: str) -> bool:
"""Query API to see if entity_id is specified state."""
cur_state = get_state(api, entity_id)
return cur_state and cur_state.state == state
return bool(cur_state and cur_state.state == state)
def get_services(api):
def get_services(api: API) -> Dict:
"""Return a list of dicts.
Each dict has a string "domain" and a list of strings "services".
@ -264,7 +269,7 @@ def get_services(api):
try:
req = api(METH_GET, URL_API_SERVICES)
return req.json() if req.status_code == 200 else {}
return req.json() if req.status_code == 200 else {} # type: ignore
except (HomeAssistantError, ValueError):
# ValueError if req.json() can't parse the json
@ -273,7 +278,9 @@ def get_services(api):
return {}
def call_service(api, domain, service, service_data=None, timeout=5):
def call_service(api: API, domain: str, service: str,
service_data: Dict = None,
timeout: int = 5) -> None:
"""Call a service at the remote API."""
try:
req = api(METH_POST,
@ -288,7 +295,7 @@ def call_service(api, domain, service, service_data=None, timeout=5):
_LOGGER.exception("Error calling service")
def get_config(api):
def get_config(api: API) -> Dict:
"""Return configuration."""
try:
req = api(METH_GET, URL_API_CONFIG)
@ -299,7 +306,7 @@ def get_config(api):
result = req.json()
if 'components' in result:
result['components'] = set(result['components'])
return result
return result # type: ignore
except (HomeAssistantError, ValueError):
# ValueError if req.json() can't parse the JSON

View File

@ -3,15 +3,18 @@ import asyncio
from functools import partial
import logging
import os
from typing import List, Dict, Optional
import homeassistant.util.package as pkg_util
from homeassistant.core import HomeAssistant
DATA_PIP_LOCK = 'pip_lock'
CONSTRAINT_FILE = 'package_constraints.txt'
_LOGGER = logging.getLogger(__name__)
async def async_process_requirements(hass, name, requirements):
async def async_process_requirements(hass: HomeAssistant, name: str,
requirements: List[str]) -> bool:
"""Install the requirements for a component or platform.
This method is a coroutine.
@ -25,7 +28,7 @@ async def async_process_requirements(hass, name, requirements):
async with pip_lock:
for req in requirements:
ret = await hass.async_add_job(pip_install, req)
ret = await hass.async_add_executor_job(pip_install, req)
if not ret:
_LOGGER.error("Not initializing %s because could not install "
"requirement %s", name, req)
@ -34,11 +37,11 @@ async def async_process_requirements(hass, name, requirements):
return True
def pip_kwargs(config_dir):
def pip_kwargs(config_dir: Optional[str]) -> Dict[str, str]:
"""Return keyword arguments for PIP install."""
kwargs = {
'constraints': os.path.join(os.path.dirname(__file__), CONSTRAINT_FILE)
}
if not pkg_util.is_virtual_env():
if not (config_dir is None or pkg_util.is_virtual_env()):
kwargs['target'] = os.path.join(config_dir, 'deps')
return kwargs

View File

@ -4,7 +4,7 @@ import logging.handlers
from timeit import default_timer as timer
from types import ModuleType
from typing import Optional, Dict
from typing import Optional, Dict, List
from homeassistant import requirements, core, loader, config as conf_util
from homeassistant.config import async_notify_setup_error
@ -56,7 +56,9 @@ async def async_setup_component(hass: core.HomeAssistant, domain: str,
return await task # type: ignore
async def _async_process_dependencies(hass, config, name, dependencies):
async def _async_process_dependencies(
hass: core.HomeAssistant, config: Dict, name: str,
dependencies: List[str]) -> bool:
"""Ensure all dependencies are set up."""
blacklisted = [dep for dep in dependencies
if dep in loader.DEPENDENCY_BLACKLIST]
@ -88,12 +90,12 @@ async def _async_process_dependencies(hass, config, name, dependencies):
async def _async_setup_component(hass: core.HomeAssistant,
domain: str, config) -> bool:
domain: str, config: Dict) -> bool:
"""Set up a component for Home Assistant.
This method is a coroutine.
"""
def log_error(msg, link=True):
def log_error(msg: str, link: bool = True) -> None:
"""Log helper."""
_LOGGER.error("Setup failed for %s: %s", domain, msg)
async_notify_setup_error(hass, domain, link)
@ -181,7 +183,7 @@ async def _async_setup_component(hass: core.HomeAssistant,
return True
async def async_prepare_setup_platform(hass: core.HomeAssistant, config,
async def async_prepare_setup_platform(hass: core.HomeAssistant, config: Dict,
domain: str, platform_name: str) \
-> Optional[ModuleType]:
"""Load a platform and makes sure dependencies are setup.
@ -190,7 +192,7 @@ async def async_prepare_setup_platform(hass: core.HomeAssistant, config,
"""
platform_path = PLATFORM_FORMAT.format(domain, platform_name)
def log_error(msg):
def log_error(msg: str) -> None:
"""Log helper."""
_LOGGER.error("Unable to prepare setup for platform %s: %s",
platform_path, msg)
@ -217,7 +219,9 @@ async def async_prepare_setup_platform(hass: core.HomeAssistant, config,
return platform
async def async_process_deps_reqs(hass, config, name, module):
async def async_process_deps_reqs(
hass: core.HomeAssistant, config: Dict, name: str,
module: ModuleType) -> None:
"""Process all dependencies and requirements for a module.
Module is a Python module of either a component or platform.
@ -231,14 +235,14 @@ async def async_process_deps_reqs(hass, config, name, module):
if hasattr(module, 'DEPENDENCIES'):
dep_success = await _async_process_dependencies(
hass, config, name, module.DEPENDENCIES)
hass, config, name, module.DEPENDENCIES) # type: ignore
if not dep_success:
raise HomeAssistantError("Could not setup all dependencies.")
if not hass.config.skip_pip and hasattr(module, 'REQUIREMENTS'):
req_success = await requirements.async_process_requirements(
hass, name, module.REQUIREMENTS)
hass, name, module.REQUIREMENTS) # type: ignore
if not req_success:
raise HomeAssistantError("Could not install all requirements.")

View File

@ -1,9 +1,8 @@
"""Helper methods for various modules."""
import asyncio
from collections.abc import MutableSet
from datetime import datetime, timedelta
from itertools import chain
import threading
from datetime import datetime
import re
import enum
import socket
@ -14,12 +13,13 @@ from types import MappingProxyType
from unicodedata import normalize
from typing import (Any, Optional, TypeVar, Callable, KeysView, Union, # noqa
Iterable, List, Mapping)
Iterable, List, Dict, Iterator, Coroutine, MutableSet)
from .dt import as_local, utcnow
T = TypeVar('T')
U = TypeVar('U')
ENUM_T = TypeVar('ENUM_T', bound=enum.Enum)
RE_SANITIZE_FILENAME = re.compile(r'(~|\.\.|/|\\)')
RE_SANITIZE_PATH = re.compile(r'(~|\.(\.)+)')
@ -91,7 +91,7 @@ def ensure_unique_string(preferred_string: str, current_strings:
# Taken from: http://stackoverflow.com/a/11735897
def get_local_ip():
def get_local_ip() -> str:
"""Try to determine the local IP address of the machine."""
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
@ -99,7 +99,7 @@ def get_local_ip():
# Use Google Public DNS server to determine own IP
sock.connect(('8.8.8.8', 80))
return sock.getsockname()[0]
return sock.getsockname()[0] # type: ignore
except socket.error:
try:
return socket.gethostbyname(socket.gethostname())
@ -110,7 +110,7 @@ def get_local_ip():
# Taken from http://stackoverflow.com/a/23728630
def get_random_string(length=10):
def get_random_string(length: int = 10) -> str:
"""Return a random string with letters and digits."""
generator = random.SystemRandom()
source_chars = string.ascii_letters + string.digits
@ -121,59 +121,59 @@ def get_random_string(length=10):
class OrderedEnum(enum.Enum):
"""Taken from Python 3.4.0 docs."""
def __ge__(self, other):
def __ge__(self: ENUM_T, other: ENUM_T) -> bool:
"""Return the greater than element."""
if self.__class__ is other.__class__:
return self.value >= other.value
return bool(self.value >= other.value)
return NotImplemented
def __gt__(self, other):
def __gt__(self: ENUM_T, other: ENUM_T) -> bool:
"""Return the greater element."""
if self.__class__ is other.__class__:
return self.value > other.value
return bool(self.value > other.value)
return NotImplemented
def __le__(self, other):
def __le__(self: ENUM_T, other: ENUM_T) -> bool:
"""Return the lower than element."""
if self.__class__ is other.__class__:
return self.value <= other.value
return bool(self.value <= other.value)
return NotImplemented
def __lt__(self, other):
def __lt__(self: ENUM_T, other: ENUM_T) -> bool:
"""Return the lower element."""
if self.__class__ is other.__class__:
return self.value < other.value
return bool(self.value < other.value)
return NotImplemented
class OrderedSet(MutableSet):
class OrderedSet(MutableSet[T]):
"""Ordered set taken from http://code.activestate.com/recipes/576694/."""
def __init__(self, iterable=None):
def __init__(self, iterable: Iterable[T] = None) -> None:
"""Initialize the set."""
self.end = end = [] # type: List[Any]
end += [None, end, end] # sentinel node for doubly linked list
self.map = {} # type: Mapping[List, Any] # key --> [key, prev, next]
self.map = {} # type: Dict[T, List] # key --> [key, prev, next]
if iterable is not None:
self |= iterable
self |= iterable # type: ignore
def __len__(self):
def __len__(self) -> int:
"""Return the length of the set."""
return len(self.map)
def __contains__(self, key):
def __contains__(self, key: T) -> bool: # type: ignore
"""Check if key is in set."""
return key in self.map
# pylint: disable=arguments-differ
def add(self, key):
def add(self, key: T) -> None:
"""Add an element to the end of the set."""
if key not in self.map:
end = self.end
curr = end[1]
curr[2] = end[1] = self.map[key] = [key, curr, end]
def promote(self, key):
def promote(self, key: T) -> None:
"""Promote element to beginning of the set, add if not there."""
if key in self.map:
self.discard(key)
@ -183,14 +183,14 @@ class OrderedSet(MutableSet):
curr[2] = begin[1] = self.map[key] = [key, curr, begin]
# pylint: disable=arguments-differ
def discard(self, key):
def discard(self, key: T) -> None:
"""Discard an element from the set."""
if key in self.map:
key, prev_item, next_item = self.map.pop(key)
prev_item[2] = next_item
next_item[1] = prev_item
def __iter__(self):
def __iter__(self) -> Iterator[T]:
"""Iterate of the set."""
end = self.end
curr = end[2]
@ -198,7 +198,7 @@ class OrderedSet(MutableSet):
yield curr[0]
curr = curr[2]
def __reversed__(self):
def __reversed__(self) -> Iterator[T]:
"""Reverse the ordering."""
end = self.end
curr = end[1]
@ -207,7 +207,7 @@ class OrderedSet(MutableSet):
curr = curr[1]
# pylint: disable=arguments-differ
def pop(self, last=True):
def pop(self, last: bool = True) -> T:
"""Pop element of the end of the set.
Set last=False to pop from the beginning.
@ -216,20 +216,20 @@ class OrderedSet(MutableSet):
raise KeyError('set is empty')
key = self.end[1][0] if last else self.end[2][0]
self.discard(key)
return key
return key # type: ignore
def update(self, *args):
def update(self, *args: Any) -> None:
"""Add elements from args to the set."""
for item in chain(*args):
self.add(item)
def __repr__(self):
def __repr__(self) -> str:
"""Return the representation."""
if not self:
return '%s()' % (self.__class__.__name__,)
return '%s(%r)' % (self.__class__.__name__, list(self))
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
"""Return the comparison."""
if isinstance(other, OrderedSet):
return len(self) == len(other) and list(self) == list(other)
@ -254,20 +254,21 @@ class Throttle:
Adds a datetime attribute `last_call` to the method.
"""
def __init__(self, min_time, limit_no_throttle=None):
def __init__(self, min_time: timedelta,
limit_no_throttle: timedelta = None) -> None:
"""Initialize the throttle."""
self.min_time = min_time
self.limit_no_throttle = limit_no_throttle
def __call__(self, method):
def __call__(self, method: Callable) -> Callable:
"""Caller for the throttle."""
# Make sure we return a coroutine if the method is async.
if asyncio.iscoroutinefunction(method):
async def throttled_value():
async def throttled_value() -> None:
"""Stand-in function for when real func is being throttled."""
return None
else:
def throttled_value():
def throttled_value() -> None: # type: ignore
"""Stand-in function for when real func is being throttled."""
return None
@ -288,14 +289,14 @@ class Throttle:
'.' not in method.__qualname__.split('.<locals>.')[-1])
@wraps(method)
def wrapper(*args, **kwargs):
def wrapper(*args: Any, **kwargs: Any) -> Union[Callable, Coroutine]:
"""Wrap that allows wrapped to be called only once per min_time.
If we cannot acquire the lock, it is running so return None.
"""
# pylint: disable=protected-access
if hasattr(method, '__self__'):
host = method.__self__
host = getattr(method, '__self__')
elif is_func:
host = wrapper
else:
@ -318,7 +319,7 @@ class Throttle:
if force or utcnow() - throttle[1] > self.min_time:
result = method(*args, **kwargs)
throttle[1] = utcnow()
return result
return result # type: ignore
return throttled_value()
finally:

View File

@ -3,22 +3,25 @@ import concurrent.futures
import threading
import logging
from asyncio import coroutines
from asyncio.events import AbstractEventLoop
from asyncio.futures import Future
from asyncio import ensure_future
from typing import Any, Union, Coroutine, Callable, Generator
_LOGGER = logging.getLogger(__name__)
def _set_result_unless_cancelled(fut, result):
def _set_result_unless_cancelled(fut: Future, result: Any) -> None:
"""Set the result only if the Future was not cancelled."""
if fut.cancelled():
return
fut.set_result(result)
def _set_concurrent_future_state(concurr, source):
def _set_concurrent_future_state(
concurr: concurrent.futures.Future,
source: Union[concurrent.futures.Future, Future]) -> None:
"""Copy state from a future to a concurrent.futures.Future."""
assert source.done()
if source.cancelled():
@ -33,7 +36,8 @@ def _set_concurrent_future_state(concurr, source):
concurr.set_result(result)
def _copy_future_state(source, dest):
def _copy_future_state(source: Union[concurrent.futures.Future, Future],
dest: Union[concurrent.futures.Future, Future]) -> None:
"""Copy state from another Future.
The other Future may be a concurrent.futures.Future.
@ -53,7 +57,9 @@ def _copy_future_state(source, dest):
dest.set_result(result)
def _chain_future(source, destination):
def _chain_future(
source: Union[concurrent.futures.Future, Future],
destination: Union[concurrent.futures.Future, Future]) -> None:
"""Chain two futures so that when one completes, so does the other.
The result (or exception) of source will be copied to destination.
@ -74,20 +80,23 @@ def _chain_future(source, destination):
else:
dest_loop = None
def _set_state(future, other):
def _set_state(future: Union[concurrent.futures.Future, Future],
other: Union[concurrent.futures.Future, Future]) -> None:
if isinstance(future, Future):
_copy_future_state(other, future)
else:
_set_concurrent_future_state(future, other)
def _call_check_cancel(destination):
def _call_check_cancel(
destination: Union[concurrent.futures.Future, Future]) -> None:
if destination.cancelled():
if source_loop is None or source_loop is dest_loop:
source.cancel()
else:
source_loop.call_soon_threadsafe(source.cancel)
def _call_set_state(source):
def _call_set_state(
source: Union[concurrent.futures.Future, Future]) -> None:
if dest_loop is None or dest_loop is source_loop:
_set_state(destination, source)
else:
@ -97,7 +106,9 @@ def _chain_future(source, destination):
source.add_done_callback(_call_set_state)
def run_coroutine_threadsafe(coro, loop):
def run_coroutine_threadsafe(
coro: Union[Coroutine, Generator],
loop: AbstractEventLoop) -> concurrent.futures.Future:
"""Submit a coroutine object to a given event loop.
Return a concurrent.futures.Future to access the result.
@ -110,7 +121,7 @@ def run_coroutine_threadsafe(coro, loop):
raise TypeError('A coroutine object is required')
future = concurrent.futures.Future() # type: concurrent.futures.Future
def callback():
def callback() -> None:
"""Handle the call to the coroutine."""
try:
_chain_future(ensure_future(coro, loop=loop), future)
@ -125,7 +136,8 @@ def run_coroutine_threadsafe(coro, loop):
return future
def fire_coroutine_threadsafe(coro, loop):
def fire_coroutine_threadsafe(coro: Coroutine,
loop: AbstractEventLoop) -> None:
"""Submit a coroutine object to a given event loop.
This method does not provide a way to retrieve the result and
@ -139,7 +151,7 @@ def fire_coroutine_threadsafe(coro, loop):
if not coroutines.iscoroutine(coro):
raise TypeError('A coroutine object is required: %s' % coro)
def callback():
def callback() -> None:
"""Handle the firing of a coroutine."""
ensure_future(coro, loop=loop)
@ -147,7 +159,8 @@ def fire_coroutine_threadsafe(coro, loop):
return
def run_callback_threadsafe(loop, callback, *args):
def run_callback_threadsafe(loop: AbstractEventLoop, callback: Callable,
*args: Any) -> concurrent.futures.Future:
"""Submit a callback object to a given event loop.
Return a concurrent.futures.Future to access the result.
@ -158,7 +171,7 @@ def run_callback_threadsafe(loop, callback, *args):
future = concurrent.futures.Future() # type: concurrent.futures.Future
def run_callback():
def run_callback() -> None:
"""Run callback and store result."""
try:
future.set_result(callback(*args))

View File

@ -2,7 +2,7 @@
import math
import colorsys
from typing import Tuple
from typing import Tuple, List
# Official CSS3 colors from w3.org:
# https://www.w3.org/TR/2010/PR-css3-color-20101028/#html4
@ -162,7 +162,7 @@ COLORS = {
}
def color_name_to_rgb(color_name):
def color_name_to_rgb(color_name: str) -> Tuple[int, int, int]:
"""Convert color name to RGB hex value."""
# COLORS map has no spaces in it, so make the color_name have no
# spaces in it as well for matching purposes
@ -305,7 +305,8 @@ def color_hsb_to_RGB(fH: float, fS: float, fB: float) -> Tuple[int, int, int]:
return (r, g, b)
def color_RGB_to_hsv(iR: int, iG: int, iB: int) -> Tuple[float, float, float]:
def color_RGB_to_hsv(
iR: float, iG: float, iB: float) -> Tuple[float, float, float]:
"""Convert an rgb color to its hsv representation.
Hue is scaled 0-360
@ -316,7 +317,7 @@ def color_RGB_to_hsv(iR: int, iG: int, iB: int) -> Tuple[float, float, float]:
return round(fHSV[0]*360, 3), round(fHSV[1]*100, 3), round(fHSV[2]*100, 3)
def color_RGB_to_hs(iR: int, iG: int, iB: int) -> Tuple[float, float]:
def color_RGB_to_hs(iR: float, iG: float, iB: float) -> Tuple[float, float]:
"""Convert an rgb color to its hs representation."""
return color_RGB_to_hsv(iR, iG, iB)[:2]
@ -340,7 +341,7 @@ def color_hs_to_RGB(iH: float, iS: float) -> Tuple[int, int, int]:
def color_xy_to_hs(vX: float, vY: float) -> Tuple[float, float]:
"""Convert an xy color to its hs representation."""
h, s, _ = color_RGB_to_hsv(*color_xy_to_RGB(vX, vY))
return (h, s)
return h, s
def color_hs_to_xy(iH: float, iS: float) -> Tuple[float, float]:
@ -348,8 +349,7 @@ def color_hs_to_xy(iH: float, iS: float) -> Tuple[float, float]:
return color_RGB_to_xy(*color_hs_to_RGB(iH, iS))
def _match_max_scale(input_colors: Tuple[int, ...],
output_colors: Tuple[int, ...]) -> Tuple[int, ...]:
def _match_max_scale(input_colors: Tuple, output_colors: Tuple) -> Tuple:
"""Match the maximum value of the output to the input."""
max_in = max(input_colors)
max_out = max(output_colors)
@ -360,7 +360,7 @@ def _match_max_scale(input_colors: Tuple[int, ...],
return tuple(int(round(i * factor)) for i in output_colors)
def color_rgb_to_rgbw(r, g, b):
def color_rgb_to_rgbw(r: int, g: int, b: int) -> Tuple[int, int, int, int]:
"""Convert an rgb color to an rgbw representation."""
# Calculate the white channel as the minimum of input rgb channels.
# Subtract the white portion from the remaining rgb channels.
@ -369,25 +369,25 @@ def color_rgb_to_rgbw(r, g, b):
# Match the output maximum value to the input. This ensures the full
# channel range is used.
return _match_max_scale((r, g, b), rgbw)
return _match_max_scale((r, g, b), rgbw) # type: ignore
def color_rgbw_to_rgb(r, g, b, w):
def color_rgbw_to_rgb(r: int, g: int, b: int, w: int) -> Tuple[int, int, int]:
"""Convert an rgbw color to an rgb representation."""
# Add the white channel back into the rgb channels.
rgb = (r + w, g + w, b + w)
# Match the output maximum value to the input. This ensures the
# output doesn't overflow.
return _match_max_scale((r, g, b, w), rgb)
return _match_max_scale((r, g, b, w), rgb) # type: ignore
def color_rgb_to_hex(r, g, b):
def color_rgb_to_hex(r: int, g: int, b: int) -> str:
"""Return a RGB color from a hex color string."""
return '{0:02x}{1:02x}{2:02x}'.format(round(r), round(g), round(b))
def rgb_hex_to_rgb_list(hex_string):
def rgb_hex_to_rgb_list(hex_string: str) -> List[int]:
"""Return an RGB color value list from a hex color string."""
return [int(hex_string[i:i + len(hex_string) // 3], 16)
for i in range(0,
@ -395,12 +395,14 @@ def rgb_hex_to_rgb_list(hex_string):
len(hex_string) // 3)]
def color_temperature_to_hs(color_temperature_kelvin):
def color_temperature_to_hs(
color_temperature_kelvin: float) -> Tuple[float, float]:
"""Return an hs color from a color temperature in Kelvin."""
return color_RGB_to_hs(*color_temperature_to_rgb(color_temperature_kelvin))
def color_temperature_to_rgb(color_temperature_kelvin):
def color_temperature_to_rgb(
color_temperature_kelvin: float) -> Tuple[float, float, float]:
"""
Return an RGB color from a color temperature in Kelvin.
@ -421,7 +423,7 @@ def color_temperature_to_rgb(color_temperature_kelvin):
blue = _get_blue(tmp_internal)
return (red, green, blue)
return red, green, blue
def _bound(color_component: float, minimum: float = 0,
@ -464,11 +466,11 @@ def _get_blue(temperature: float) -> float:
return _bound(blue)
def color_temperature_mired_to_kelvin(mired_temperature):
def color_temperature_mired_to_kelvin(mired_temperature: float) -> float:
"""Convert absolute mired shift to degrees kelvin."""
return math.floor(1000000 / mired_temperature)
def color_temperature_kelvin_to_mired(kelvin_temperature):
def color_temperature_kelvin_to_mired(kelvin_temperature: float) -> float:
"""Convert degrees kelvin to mired shift."""
return math.floor(1000000 / kelvin_temperature)

View File

@ -1,12 +1,14 @@
"""Decorator utility functions."""
from typing import Callable, TypeVar
CALLABLE_T = TypeVar('CALLABLE_T', bound=Callable)
class Registry(dict):
"""Registry of items."""
def register(self, name):
def register(self, name: str) -> Callable[[CALLABLE_T], CALLABLE_T]:
"""Return decorator to register item with a specific name."""
def decorator(func):
def decorator(func: CALLABLE_T) -> CALLABLE_T:
"""Register decorated function."""
self[name] = func
return func

View File

@ -71,14 +71,14 @@ def as_utc(dattim: dt.datetime) -> dt.datetime:
return dattim.astimezone(UTC)
def as_timestamp(dt_value):
def as_timestamp(dt_value: dt.datetime) -> float:
"""Convert a date/time into a unix time (seconds since 1970)."""
if hasattr(dt_value, "timestamp"):
parsed_dt = dt_value
parsed_dt = dt_value # type: Optional[dt.datetime]
else:
parsed_dt = parse_datetime(str(dt_value))
if not parsed_dt:
raise ValueError("not a valid date/time.")
if parsed_dt is None:
raise ValueError("not a valid date/time.")
return parsed_dt.timestamp()
@ -150,7 +150,7 @@ def parse_date(dt_str: str) -> Optional[dt.date]:
return None
def parse_time(time_str):
def parse_time(time_str: str) -> Optional[dt.time]:
"""Parse a time string (00:20:00) into Time object.
Return None if invalid.

View File

@ -38,7 +38,7 @@ def load_json(filename: str, default: Union[List, Dict, None] = None) \
return {} if default is None else default
def save_json(filename: str, data: Union[List, Dict]):
def save_json(filename: str, data: Union[List, Dict]) -> None:
"""Save JSON data to a file.
Returns True on success.

View File

@ -33,7 +33,7 @@ LocationInfo = collections.namedtuple(
'use_metric'])
def detect_location_info():
def detect_location_info() -> Optional[LocationInfo]:
"""Detect location information."""
data = _get_freegeoip()
@ -63,7 +63,7 @@ def distance(lat1: Optional[float], lon1: Optional[float],
return result * 1000
def elevation(latitude, longitude):
def elevation(latitude: float, longitude: float) -> int:
"""Return elevation for given latitude and longitude."""
try:
req = requests.get(

View File

@ -1,7 +1,9 @@
"""Logging utilities."""
import asyncio
from asyncio.events import AbstractEventLoop
import logging
import threading
from typing import Optional
from .async_ import run_coroutine_threadsafe
@ -9,12 +11,12 @@ from .async_ import run_coroutine_threadsafe
class HideSensitiveDataFilter(logging.Filter):
"""Filter API password calls."""
def __init__(self, text):
def __init__(self, text: str) -> None:
"""Initialize sensitive data filter."""
super().__init__()
self.text = text
def filter(self, record):
def filter(self, record: logging.LogRecord) -> bool:
"""Hide sensitive data in messages."""
record.msg = record.msg.replace(self.text, '*******')
@ -25,7 +27,8 @@ class HideSensitiveDataFilter(logging.Filter):
class AsyncHandler:
"""Logging handler wrapper to add an async layer."""
def __init__(self, loop, handler):
def __init__(
self, loop: AbstractEventLoop, handler: logging.Handler) -> None:
"""Initialize async logging handler wrapper."""
self.handler = handler
self.loop = loop
@ -45,11 +48,11 @@ class AsyncHandler:
self._thread.start()
def close(self):
def close(self) -> None:
"""Wrap close to handler."""
self.emit(None)
async def async_close(self, blocking=False):
async def async_close(self, blocking: bool = False) -> None:
"""Close the handler.
When blocking=True, will wait till closed.
@ -60,7 +63,7 @@ class AsyncHandler:
while self._thread.is_alive():
await asyncio.sleep(0, loop=self.loop)
def emit(self, record):
def emit(self, record: Optional[logging.LogRecord]) -> None:
"""Process a record."""
ident = self.loop.__dict__.get("_thread_ident")
@ -71,11 +74,11 @@ class AsyncHandler:
else:
self.loop.call_soon_threadsafe(self._queue.put_nowait, record)
def __repr__(self):
def __repr__(self) -> str:
"""Return the string names."""
return str(self.handler)
def _process(self):
def _process(self) -> None:
"""Process log in a thread."""
while True:
record = run_coroutine_threadsafe(
@ -87,34 +90,34 @@ class AsyncHandler:
self.handler.emit(record)
def createLock(self):
def createLock(self) -> None:
"""Ignore lock stuff."""
pass
def acquire(self):
def acquire(self) -> None:
"""Ignore lock stuff."""
pass
def release(self):
def release(self) -> None:
"""Ignore lock stuff."""
pass
@property
def level(self):
def level(self) -> int:
"""Wrap property level to handler."""
return self.handler.level
@property
def formatter(self):
def formatter(self) -> Optional[logging.Formatter]:
"""Wrap property formatter to handler."""
return self.handler.formatter
@property
def name(self):
def name(self) -> str:
"""Wrap property set_name to handler."""
return self.handler.get_name()
return self.handler.get_name() # type: ignore
@name.setter
def name(self, name):
def name(self, name: str) -> None:
"""Wrap property get_name to handler."""
self.handler.name = name
self.handler.set_name(name) # type: ignore

View File

@ -16,7 +16,7 @@ _LOGGER = logging.getLogger(__name__)
INSTALL_LOCK = threading.Lock()
def is_virtual_env():
def is_virtual_env() -> bool:
"""Return if we run in a virtual environtment."""
# Check supports venv && virtualenv
return (getattr(sys, 'base_prefix', sys.prefix) != sys.prefix or

View File

@ -4,7 +4,7 @@ import ssl
import certifi
def client_context():
def client_context() -> ssl.SSLContext:
"""Return an SSL context for making requests."""
context = ssl.create_default_context(
purpose=ssl.Purpose.SERVER_AUTH,
@ -13,7 +13,7 @@ def client_context():
return context
def server_context():
def server_context() -> ssl.SSLContext:
"""Return an SSL context following the Mozilla recommendations.
TLS configuration follows the best-practice guidelines specified here:

View File

@ -4,7 +4,7 @@ import os
import sys
import fnmatch
from collections import OrderedDict
from typing import Union, List, Dict
from typing import Union, List, Dict, Iterator, overload, TypeVar
import yaml
try:
@ -22,7 +22,10 @@ from homeassistant.exceptions import HomeAssistantError
_LOGGER = logging.getLogger(__name__)
_SECRET_NAMESPACE = 'homeassistant'
SECRET_YAML = 'secrets.yaml'
__SECRET_CACHE = {} # type: Dict
__SECRET_CACHE = {} # type: Dict[str, JSON_TYPE]
JSON_TYPE = Union[List, Dict, str]
DICT_T = TypeVar('DICT_T', bound=Dict)
class NodeListClass(list):
@ -37,7 +40,42 @@ class NodeStrClass(str):
pass
def _add_reference(obj, loader, node):
# pylint: disable=too-many-ancestors
class SafeLineLoader(yaml.SafeLoader):
"""Loader class that keeps track of line numbers."""
def compose_node(self, parent: yaml.nodes.Node,
index: int) -> yaml.nodes.Node:
"""Annotate a node with the first line it was seen."""
last_line = self.line # type: int
node = super(SafeLineLoader,
self).compose_node(parent, index) # type: yaml.nodes.Node
node.__line__ = last_line + 1 # type: ignore
return node
# pylint: disable=pointless-statement
@overload
def _add_reference(obj: Union[list, NodeListClass],
loader: yaml.SafeLoader,
node: yaml.nodes.Node) -> NodeListClass: ...
@overload # noqa: F811
def _add_reference(obj: Union[str, NodeStrClass],
loader: yaml.SafeLoader,
node: yaml.nodes.Node) -> NodeStrClass: ...
@overload # noqa: F811
def _add_reference(obj: DICT_T,
loader: yaml.SafeLoader,
node: yaml.nodes.Node) -> DICT_T: ...
# pylint: enable=pointless-statement
def _add_reference(obj, loader: SafeLineLoader, # type: ignore # noqa: F811
node: yaml.nodes.Node):
"""Add file reference information to an object."""
if isinstance(obj, list):
obj = NodeListClass(obj)
@ -48,20 +86,7 @@ def _add_reference(obj, loader, node):
return obj
# pylint: disable=too-many-ancestors
class SafeLineLoader(yaml.SafeLoader):
"""Loader class that keeps track of line numbers."""
def compose_node(self, parent: yaml.nodes.Node, index) -> yaml.nodes.Node:
"""Annotate a node with the first line it was seen."""
last_line = self.line # type: int
node = super(SafeLineLoader,
self).compose_node(parent, index) # type: yaml.nodes.Node
node.__line__ = last_line + 1 # type: ignore
return node
def load_yaml(fname: str) -> Union[List, Dict]:
def load_yaml(fname: str) -> JSON_TYPE:
"""Load a YAML file."""
try:
with open(fname, encoding='utf-8') as conf_file:
@ -83,12 +108,12 @@ def dump(_dict: dict) -> str:
.replace(': null\n', ':\n')
def save_yaml(path, data):
def save_yaml(path: str, data: dict) -> None:
"""Save YAML to a file."""
# Dump before writing to not truncate the file if dumping fails
data = dump(data)
str_data = dump(data)
with open(path, 'w', encoding='utf-8') as outfile:
outfile.write(data)
outfile.write(str_data)
def clear_secret_cache() -> None:
@ -100,7 +125,7 @@ def clear_secret_cache() -> None:
def _include_yaml(loader: SafeLineLoader,
node: yaml.nodes.Node) -> Union[List, Dict]:
node: yaml.nodes.Node) -> JSON_TYPE:
"""Load another YAML file and embeds it using the !include tag.
Example:
@ -115,7 +140,7 @@ def _is_file_valid(name: str) -> bool:
return not name.startswith('.')
def _find_files(directory: str, pattern: str):
def _find_files(directory: str, pattern: str) -> Iterator[str]:
"""Recursively load files in a directory."""
for root, dirs, files in os.walk(directory, topdown=True):
dirs[:] = [d for d in dirs if _is_file_valid(d)]
@ -151,7 +176,7 @@ def _include_dir_merge_named_yaml(loader: SafeLineLoader,
def _include_dir_list_yaml(loader: SafeLineLoader,
node: yaml.nodes.Node):
node: yaml.nodes.Node) -> List[JSON_TYPE]:
"""Load multiple files from directory as a list."""
loc = os.path.join(os.path.dirname(loader.name), node.value)
return [load_yaml(f) for f in _find_files(loc, '*.yaml')
@ -159,11 +184,11 @@ def _include_dir_list_yaml(loader: SafeLineLoader,
def _include_dir_merge_list_yaml(loader: SafeLineLoader,
node: yaml.nodes.Node):
node: yaml.nodes.Node) -> JSON_TYPE:
"""Load multiple files from directory as a merged list."""
loc = os.path.join(os.path.dirname(loader.name),
node.value) # type: str
merged_list = [] # type: List
merged_list = [] # type: List[JSON_TYPE]
for fname in _find_files(loc, '*.yaml'):
if os.path.basename(fname) == SECRET_YAML:
continue
@ -202,14 +227,14 @@ def _ordered_dict(loader: SafeLineLoader,
return _add_reference(OrderedDict(nodes), loader, node)
def _construct_seq(loader: SafeLineLoader, node: yaml.nodes.Node):
def _construct_seq(loader: SafeLineLoader, node: yaml.nodes.Node) -> JSON_TYPE:
"""Add line number and file name to Load YAML sequence."""
obj, = loader.construct_yaml_seq(node)
return _add_reference(obj, loader, node)
def _env_var_yaml(loader: SafeLineLoader,
node: yaml.nodes.Node):
node: yaml.nodes.Node) -> str:
"""Load environment variables and embed it into the configuration YAML."""
args = node.value.split()
@ -222,7 +247,7 @@ def _env_var_yaml(loader: SafeLineLoader,
raise HomeAssistantError(node.value)
def _load_secret_yaml(secret_path: str) -> Dict:
def _load_secret_yaml(secret_path: str) -> JSON_TYPE:
"""Load the secrets yaml from path."""
secret_path = os.path.join(secret_path, SECRET_YAML)
if secret_path in __SECRET_CACHE:
@ -248,7 +273,7 @@ def _load_secret_yaml(secret_path: str) -> Dict:
def _secret_yaml(loader: SafeLineLoader,
node: yaml.nodes.Node):
node: yaml.nodes.Node) -> JSON_TYPE:
"""Load secrets and embed it into the configuration YAML."""
secret_path = os.path.dirname(loader.name)
while True:
@ -308,7 +333,8 @@ yaml.SafeLoader.add_constructor('!include_dir_merge_named',
# From: https://gist.github.com/miracle2k/3184458
# pylint: disable=redefined-outer-name
def represent_odict(dump, tag, mapping, flow_style=None):
def represent_odict(dump, tag, mapping, # type: ignore
flow_style=None) -> yaml.MappingNode:
"""Like BaseRepresenter.represent_mapping but does not issue the sort()."""
value = [] # type: list
node = yaml.MappingNode(tag, value, flow_style=flow_style)

View File

@ -2,11 +2,18 @@
check_untyped_defs = true
follow_imports = silent
ignore_missing_imports = true
warn_incomplete_stub = true
warn_redundant_casts = true
warn_return_any = true
warn_unused_configs = true
warn_unused_ignores = true
[mypy-homeassistant.*]
disallow_untyped_defs = true
[mypy-homeassistant.config_entries]
disallow_untyped_defs = false
[mypy-homeassistant.util.yaml]
warn_return_any = false

View File

@ -437,10 +437,12 @@ class TestAutomation(unittest.TestCase):
}
}
}}):
automation.reload(self.hass)
self.hass.block_till_done()
# De-flake ?!
self.hass.block_till_done()
with patch('homeassistant.config.find_config_file',
return_value=''):
automation.reload(self.hass)
self.hass.block_till_done()
# De-flake ?!
self.hass.block_till_done()
assert self.hass.states.get('automation.hello') is None
assert self.hass.states.get('automation.bye') is not None
@ -485,8 +487,10 @@ class TestAutomation(unittest.TestCase):
with patch('homeassistant.config.load_yaml_config_file', autospec=True,
return_value={automation.DOMAIN: 'not valid'}):
automation.reload(self.hass)
self.hass.block_till_done()
with patch('homeassistant.config.find_config_file',
return_value=''):
automation.reload(self.hass)
self.hass.block_till_done()
assert self.hass.states.get('automation.hello') is None
@ -521,8 +525,10 @@ class TestAutomation(unittest.TestCase):
with patch('homeassistant.config.load_yaml_config_file',
side_effect=HomeAssistantError('bla')):
automation.reload(self.hass)
self.hass.block_till_done()
with patch('homeassistant.config.find_config_file',
return_value=''):
automation.reload(self.hass)
self.hass.block_till_done()
assert self.hass.states.get('automation.hello') is not None

View File

@ -365,8 +365,10 @@ class TestComponentsGroup(unittest.TestCase):
'icon': 'mdi:work',
'view': True,
}}}):
group.reload(self.hass)
self.hass.block_till_done()
with patch('homeassistant.config.find_config_file',
return_value=''):
group.reload(self.hass)
self.hass.block_till_done()
assert sorted(self.hass.states.entity_ids()) == \
['group.all_tests', 'group.hello']

View File

@ -199,8 +199,10 @@ class TestScriptComponent(unittest.TestCase):
}
}]
}}}):
script.reload(self.hass)
self.hass.block_till_done()
with patch('homeassistant.config.find_config_file',
return_value=''):
script.reload(self.hass)
self.hass.block_till_done()
assert self.hass.states.get(ENTITY_ID) is None
assert not self.hass.services.has_service(script.DOMAIN, 'test')