1
mirror of https://github.com/home-assistant/core synced 2024-10-04 07:58:43 +02:00

Spread async love (#3575)

* Convert Entity.update_ha_state to be async

* Make Service.call async

* Update entity.py

* Add Entity.async_update

* Make automation zone trigger async

* Fix linting

* Reduce flakiness in hass.block_till_done

* Make automation.numeric_state async

* Make mqtt.subscribe async

* Make automation.mqtt async

* Make automation.time async

* Make automation.sun async

* Add async_track_point_in_utc_time

* Make helpers.track_sunrise/set async

* Add async_track_state_change

* Make automation.state async

* Clean up helpers/entity.py tests

* Lint

* Lint

* Core.is_state and Core.is_state_attr are async friendly

* Lint

* Lint
This commit is contained in:
Paulus Schoutsen 2016-09-30 12:57:24 -07:00 committed by GitHub
parent 7e50ccd32a
commit b650b2b0db
17 changed files with 323 additions and 151 deletions

View File

@ -4,6 +4,7 @@ Offer MQTT listening automation rules.
For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#mqtt-trigger
"""
import asyncio
import voluptuous as vol
import homeassistant.components.mqtt as mqtt
@ -26,10 +27,11 @@ def trigger(hass, config, action):
topic = config.get(CONF_TOPIC)
payload = config.get(CONF_PAYLOAD)
@asyncio.coroutine
def mqtt_automation_listener(msg_topic, msg_payload, qos):
"""Listen for MQTT messages."""
if payload is None or payload == msg_payload:
action({
hass.async_add_job(action, {
'trigger': {
'platform': 'mqtt',
'topic': msg_topic,

View File

@ -4,6 +4,7 @@ Offer numeric state listening automation rules.
For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#numeric-state-trigger
"""
import asyncio
import logging
import voluptuous as vol
@ -34,7 +35,7 @@ def trigger(hass, config, action):
if value_template is not None:
value_template.hass = hass
# pylint: disable=unused-argument
@asyncio.coroutine
def state_automation_listener(entity, from_s, to_s):
"""Listen for state changes and calls action."""
if to_s is None:
@ -50,19 +51,19 @@ def trigger(hass, config, action):
}
# If new one doesn't match, nothing to do
if not condition.numeric_state(
if not condition.async_numeric_state(
hass, to_s, below, above, value_template, variables):
return
# Only match if old didn't exist or existed but didn't match
# Written as: skip if old one did exist and matched
if from_s is not None and condition.numeric_state(
if from_s is not None and condition.async_numeric_state(
hass, from_s, below, above, value_template, variables):
return
variables['trigger']['from_state'] = from_s
variables['trigger']['to_state'] = to_s
action(variables)
hass.async_add_job(action, variables)
return track_state_change(hass, entity_id, state_automation_listener)

View File

@ -4,12 +4,15 @@ Offer state listening automation rules.
For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#state-trigger
"""
import asyncio
import voluptuous as vol
import homeassistant.util.dt as dt_util
from homeassistant.const import MATCH_ALL, CONF_PLATFORM
from homeassistant.helpers.event import track_state_change, track_point_in_time
from homeassistant.helpers.event import (
async_track_state_change, async_track_point_in_utc_time)
import homeassistant.helpers.config_validation as cv
from homeassistant.util.async import run_callback_threadsafe
CONF_ENTITY_ID = "entity_id"
CONF_FROM = "from"
@ -38,16 +41,17 @@ def trigger(hass, config, action):
from_state = config.get(CONF_FROM, MATCH_ALL)
to_state = config.get(CONF_TO) or config.get(CONF_STATE) or MATCH_ALL
time_delta = config.get(CONF_FOR)
remove_state_for_cancel = None
remove_state_for_listener = None
async_remove_state_for_cancel = None
async_remove_state_for_listener = None
@asyncio.coroutine
def state_automation_listener(entity, from_s, to_s):
"""Listen for state changes and calls action."""
nonlocal remove_state_for_cancel, remove_state_for_listener
nonlocal async_remove_state_for_cancel, async_remove_state_for_listener
def call_action():
"""Call action with right context."""
action({
hass.async_add_job(action, {
'trigger': {
'platform': 'state',
'entity_id': entity,
@ -61,35 +65,41 @@ def trigger(hass, config, action):
call_action()
return
@asyncio.coroutine
def state_for_listener(now):
"""Fire on state changes after a delay and calls action."""
remove_state_for_cancel()
async_remove_state_for_cancel()
call_action()
@asyncio.coroutine
def state_for_cancel_listener(entity, inner_from_s, inner_to_s):
"""Fire on changes and cancel for listener if changed."""
if inner_to_s.state == to_s.state:
return
remove_state_for_listener()
remove_state_for_cancel()
async_remove_state_for_listener()
async_remove_state_for_cancel()
remove_state_for_listener = track_point_in_time(
async_remove_state_for_listener = async_track_point_in_utc_time(
hass, state_for_listener, dt_util.utcnow() + time_delta)
remove_state_for_cancel = track_state_change(
async_remove_state_for_cancel = async_track_state_change(
hass, entity, state_for_cancel_listener)
unsub = track_state_change(hass, entity_id, state_automation_listener,
from_state, to_state)
unsub = async_track_state_change(
hass, entity_id, state_automation_listener, from_state, to_state)
def async_remove():
"""Remove state listeners async."""
unsub()
# pylint: disable=not-callable
if async_remove_state_for_cancel is not None:
async_remove_state_for_cancel()
if async_remove_state_for_listener is not None:
async_remove_state_for_listener()
def remove():
"""Remove state listeners."""
unsub()
# pylint: disable=not-callable
if remove_state_for_cancel is not None:
remove_state_for_cancel()
if remove_state_for_listener is not None:
remove_state_for_listener()
run_callback_threadsafe(hass.loop, async_remove).result()
return remove

View File

@ -4,6 +4,7 @@ Offer sun based automation rules.
For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#sun-trigger
"""
import asyncio
from datetime import timedelta
import logging
@ -30,9 +31,10 @@ def trigger(hass, config, action):
event = config.get(CONF_EVENT)
offset = config.get(CONF_OFFSET)
@asyncio.coroutine
def call_action():
"""Call action with right context."""
action({
hass.async_add_job(action, {
'trigger': {
'platform': 'sun',
'event': event,

View File

@ -4,6 +4,7 @@ Offer time listening automation rules.
For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#time-trigger
"""
import asyncio
import logging
import voluptuous as vol
@ -38,9 +39,10 @@ def trigger(hass, config, action):
minutes = config.get(CONF_MINUTES)
seconds = config.get(CONF_SECONDS)
@asyncio.coroutine
def time_automation_listener(now):
"""Listen for time changes and calls action."""
action({
hass.async_add_job(action, {
'trigger': {
'platform': 'time',
'now': now,

View File

@ -4,6 +4,7 @@ Offer zone automation rules.
For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#zone-trigger
"""
import asyncio
import voluptuous as vol
from homeassistant.const import (
@ -31,6 +32,7 @@ def trigger(hass, config, action):
zone_entity_id = config.get(CONF_ZONE)
event = config.get(CONF_EVENT)
@asyncio.coroutine
def zone_automation_listener(entity, from_s, to_s):
"""Listen for state changes and calls action."""
if from_s and not location.has_location(from_s) or \
@ -47,7 +49,7 @@ def trigger(hass, config, action):
# pylint: disable=too-many-boolean-expressions
if event == EVENT_ENTER and not from_match and to_match or \
event == EVENT_LEAVE and from_match and not to_match:
action({
hass.async_add_job(action, {
'trigger': {
'platform': 'zone',
'entity_id': entity,

View File

@ -4,6 +4,7 @@ Event parser and human readable log generator.
For more details about this component, please refer to the documentation at
https://home-assistant.io/components/logbook/
"""
import asyncio
import logging
from datetime import timedelta
from itertools import groupby
@ -20,6 +21,7 @@ from homeassistant.const import (EVENT_HOMEASSISTANT_START,
STATE_NOT_HOME, STATE_OFF, STATE_ON,
ATTR_HIDDEN)
from homeassistant.core import State, split_entity_id, DOMAIN as HA_DOMAIN
from homeassistant.util.async import run_callback_threadsafe
DOMAIN = "logbook"
DEPENDENCIES = ['recorder', 'frontend']
@ -57,6 +59,13 @@ LOG_MESSAGE_SCHEMA = vol.Schema({
def log_entry(hass, name, message, domain=None, entity_id=None):
"""Add an entry to the logbook."""
run_callback_threadsafe(
hass.loop, async_log_entry, hass, name, message, domain, entity_id
).result()
def async_log_entry(hass, name, message, domain=None, entity_id=None):
"""Add an entry to the logbook."""
data = {
ATTR_NAME: name,
@ -67,11 +76,12 @@ def log_entry(hass, name, message, domain=None, entity_id=None):
data[ATTR_DOMAIN] = domain
if entity_id is not None:
data[ATTR_ENTITY_ID] = entity_id
hass.bus.fire(EVENT_LOGBOOK_ENTRY, data)
hass.bus.async_fire(EVENT_LOGBOOK_ENTRY, data)
def setup(hass, config):
"""Listen for download events to download files."""
@asyncio.coroutine
def log_message(service):
"""Handle sending notification message service calls."""
message = service.data[ATTR_MESSAGE]
@ -80,8 +90,8 @@ def setup(hass, config):
entity_id = service.data.get(ATTR_ENTITY_ID)
message.hass = hass
message = message.render()
log_entry(hass, name, message, domain, entity_id)
message = message.async_render()
async_log_entry(hass, name, message, domain, entity_id)
hass.wsgi.register_view(LogbookView(hass, config))

View File

@ -4,6 +4,7 @@ Support for MQTT message handling.
For more details about this component, please refer to the documentation at
https://home-assistant.io/components/mqtt/
"""
import asyncio
import logging
import os
import socket
@ -11,6 +12,7 @@ import time
import voluptuous as vol
from homeassistant.core import JobPriority
from homeassistant.bootstrap import prepare_setup_platform
from homeassistant.config import load_yaml_config_file
from homeassistant.exceptions import HomeAssistantError
@ -164,11 +166,20 @@ def publish_template(hass, topic, payload_template, qos=None, retain=None):
def subscribe(hass, topic, callback, qos=DEFAULT_QOS):
"""Subscribe to an MQTT topic."""
@asyncio.coroutine
def mqtt_topic_subscriber(event):
"""Match subscribed MQTT topic."""
if _match_topic(topic, event.data[ATTR_TOPIC]):
callback(event.data[ATTR_TOPIC], event.data[ATTR_PAYLOAD],
event.data[ATTR_QOS])
if not _match_topic(topic, event.data[ATTR_TOPIC]):
return
if asyncio.iscoroutinefunction(callback):
yield from callback(
event.data[ATTR_TOPIC], event.data[ATTR_PAYLOAD],
event.data[ATTR_QOS])
else:
hass.add_job(callback, event.data[ATTR_TOPIC],
event.data[ATTR_PAYLOAD], event.data[ATTR_QOS],
priority=JobPriority.EVENT_CALLBACK)
remove = hass.bus.listen(EVENT_MQTT_MESSAGE_RECEIVED,
mqtt_topic_subscriber)

View File

@ -248,12 +248,16 @@ class HomeAssistant(object):
def notify_when_done():
"""Notify event loop when pool done."""
count = 0
while True:
# Wait for the work queue to empty
self.pool.block_till_done()
# Verify the loop is empty
if self._loop_empty():
count += 1
if count == 2:
break
# sleep in the loop executor, this forces execution back into
@ -675,40 +679,29 @@ class StateMachine(object):
return list(self._states.values())
def get(self, entity_id):
"""Retrieve state of entity_id or None if not found."""
"""Retrieve state of entity_id or None if not found.
Async friendly.
"""
return self._states.get(entity_id.lower())
def is_state(self, entity_id, state):
"""Test if entity exists and is specified state."""
return run_callback_threadsafe(
self._loop, self.async_is_state, entity_id, state
).result()
def async_is_state(self, entity_id, state):
"""Test if entity exists and is specified state.
This method must be run in the event loop.
Async friendly.
"""
entity_id = entity_id.lower()
state_obj = self.get(entity_id)
return (entity_id in self._states and
self._states[entity_id].state == state)
return state_obj and state_obj.state == state
def is_state_attr(self, entity_id, name, value):
"""Test if entity exists and has a state attribute set to value."""
return run_callback_threadsafe(
self._loop, self.async_is_state_attr, entity_id, name, value
).result()
def async_is_state_attr(self, entity_id, name, value):
"""Test if entity exists and has a state attribute set to value.
This method must be run in the event loop.
Async friendly.
"""
entity_id = entity_id.lower()
state_obj = self.get(entity_id)
return (entity_id in self._states and
self._states[entity_id].attributes.get(name, None) == value)
return state_obj and state_obj.attributes.get(name, None) == value
def remove(self, entity_id):
"""Remove the state of an entity.
@ -799,7 +792,8 @@ class StateMachine(object):
class Service(object):
"""Represents a callable service."""
__slots__ = ['func', 'description', 'fields', 'schema']
__slots__ = ['func', 'description', 'fields', 'schema',
'iscoroutinefunction']
def __init__(self, func, description, fields, schema):
"""Initialize a service."""
@ -807,6 +801,7 @@ class Service(object):
self.description = description or ''
self.fields = fields or {}
self.schema = schema
self.iscoroutinefunction = asyncio.iscoroutinefunction(func)
def as_dict(self):
"""Return dictionary representation of this service."""
@ -815,19 +810,6 @@ class Service(object):
'fields': self.fields,
}
def __call__(self, call):
"""Execute the service."""
try:
if self.schema:
call.data = self.schema(call.data)
call.data = MappingProxyType(call.data)
self.func(call)
except vol.MultipleInvalid as ex:
_LOGGER.error('Invalid service data for %s.%s: %s',
call.domain, call.service,
humanize_error(call.data, ex))
# pylint: disable=too-few-public-methods
class ServiceCall(object):
@ -839,7 +821,7 @@ class ServiceCall(object):
"""Initialize a service call."""
self.domain = domain.lower()
self.service = service.lower()
self.data = data or {}
self.data = MappingProxyType(data or {})
self.call_id = call_id
def __repr__(self):
@ -983,9 +965,9 @@ class ServiceRegistry(object):
fut = asyncio.Future(loop=self._loop)
@asyncio.coroutine
def service_executed(call):
def service_executed(event):
"""Callback method that is called when service is executed."""
if call.data[ATTR_SERVICE_CALL_ID] == call_id:
if event.data[ATTR_SERVICE_CALL_ID] == call_id:
fut.set_result(True)
unsub = self._bus.async_listen(EVENT_SERVICE_EXECUTED,
@ -1000,9 +982,10 @@ class ServiceRegistry(object):
unsub()
return success
@asyncio.coroutine
def _event_to_service_call(self, event):
"""Callback for SERVICE_CALLED events from the event bus."""
service_data = event.data.get(ATTR_SERVICE_DATA)
service_data = event.data.get(ATTR_SERVICE_DATA) or {}
domain = event.data.get(ATTR_DOMAIN).lower()
service = event.data.get(ATTR_SERVICE).lower()
call_id = event.data.get(ATTR_SERVICE_CALL_ID)
@ -1014,19 +997,41 @@ class ServiceRegistry(object):
return
service_handler = self._services[domain][service]
def fire_service_executed():
"""Fire service executed event."""
if not call_id:
return
data = {ATTR_SERVICE_CALL_ID: call_id}
if service_handler.iscoroutinefunction:
self._bus.async_fire(EVENT_SERVICE_EXECUTED, data)
else:
self._bus.fire(EVENT_SERVICE_EXECUTED, data)
try:
if service_handler.schema:
service_data = service_handler.schema(service_data)
except vol.Invalid as ex:
_LOGGER.error('Invalid service data for %s.%s: %s',
domain, service, humanize_error(service_data, ex))
fire_service_executed()
return
service_call = ServiceCall(domain, service, service_data, call_id)
# Add a job to the pool that calls _execute_service
self._add_job(self._execute_service, service_handler, service_call,
priority=JobPriority.EVENT_SERVICE)
if not service_handler.iscoroutinefunction:
def execute_service():
"""Execute a service and fires a SERVICE_EXECUTED event."""
service_handler.func(service_call)
fire_service_executed()
def _execute_service(self, service, call):
"""Execute a service and fires a SERVICE_EXECUTED event."""
service(call)
self._add_job(execute_service, priority=JobPriority.EVENT_SERVICE)
return
if call.call_id is not None:
self._bus.fire(
EVENT_SERVICE_EXECUTED, {ATTR_SERVICE_CALL_ID: call.call_id})
yield from service_handler.func(service_call)
fire_service_executed()
def _generate_unique_id(self):
"""Generate a unique service call id."""

View File

@ -84,6 +84,15 @@ def or_from_config(config: ConfigType, config_validation: bool=True):
def numeric_state(hass: HomeAssistant, entity, below=None, above=None,
value_template=None, variables=None):
"""Test a numeric state condition."""
return run_callback_threadsafe(
hass.loop, async_numeric_state, hass, entity, below, above,
value_template, variables,
).result()
def async_numeric_state(hass: HomeAssistant, entity, below=None, above=None,
value_template=None, variables=None):
"""Test a numeric state condition."""
if isinstance(entity, str):
entity = hass.states.get(entity)
@ -96,7 +105,7 @@ def numeric_state(hass: HomeAssistant, entity, below=None, above=None,
variables = dict(variables or {})
variables['state'] = entity
try:
value = value_template.render(variables)
value = value_template.async_render(variables)
except TemplateError as ex:
_LOGGER.error("Template error: %s", ex)
return False
@ -290,7 +299,10 @@ def time_from_config(config, config_validation=True):
def zone(hass, zone_ent, entity):
"""Test if zone-condition matches."""
"""Test if zone-condition matches.
Can be run async.
"""
if isinstance(zone_ent, str):
zone_ent = hass.states.get(zone_ent)

View File

@ -1,4 +1,5 @@
"""An abstract class for entities."""
import asyncio
import logging
from typing import Any, Optional, List, Dict
@ -11,6 +12,7 @@ from homeassistant.const import (
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import NoEntitySpecifiedError
from homeassistant.util import ensure_unique_string, slugify
from homeassistant.util.async import run_coroutine_threadsafe
# Entity attributes that we will overwrite
_OVERWRITE = {} # type: Dict[str, Any]
@ -143,6 +145,23 @@ class Entity(object):
If force_refresh == True will update entity before setting state.
"""
# We're already in a thread, do the force refresh here.
if force_refresh and not hasattr(self, 'async_update'):
self.update()
force_refresh = False
run_coroutine_threadsafe(
self.async_update_ha_state(force_refresh), self.hass.loop
).result()
@asyncio.coroutine
def async_update_ha_state(self, force_refresh=False):
"""Update Home Assistant with current state of entity.
If force_refresh == True will update entity before setting state.
This method must be run in the event loop.
"""
if self.hass is None:
raise RuntimeError("Attribute hass is None for {}".format(self))
@ -151,7 +170,13 @@ class Entity(object):
"No entity id specified for entity {}".format(self.name))
if force_refresh:
self.update()
if hasattr(self, 'async_update'):
# pylint: disable=no-member
self.async_update()
else:
# PS: Run this in our own thread pool once we have
# future support?
yield from self.hass.loop.run_in_executor(None, self.update)
state = STATE_UNKNOWN if self.state is None else str(self.state)
attr = self.state_attributes or {}
@ -192,7 +217,7 @@ class Entity(object):
# Could not convert state to float
pass
return self.hass.states.set(
self.hass.states.async_set(
self.entity_id, state, attr, self.force_update)
def remove(self) -> None:

View File

@ -18,6 +18,28 @@ def track_state_change(hass, entity_ids, action, from_state=None,
Returns a function that can be called to remove the listener.
"""
async_unsub = run_callback_threadsafe(
hass.loop, async_track_state_change, hass, entity_ids, action,
from_state, to_state).result()
def remove():
"""Remove listener."""
run_callback_threadsafe(hass.loop, async_unsub).result()
return remove
def async_track_state_change(hass, entity_ids, action, from_state=None,
to_state=None):
"""Track specific state changes.
entity_ids, from_state and to_state can be string or list.
Use list to match multiple.
Returns a function that can be called to remove the listener.
Must be run within the event loop.
"""
from_state = _process_state_match(from_state)
to_state = _process_state_match(to_state)
@ -52,7 +74,7 @@ def track_state_change(hass, entity_ids, action, from_state=None,
event.data.get('old_state'),
event.data.get('new_state'))
return hass.bus.listen(EVENT_STATE_CHANGED, state_change_listener)
return hass.bus.async_listen(EVENT_STATE_CHANGED, state_change_listener)
def track_point_in_time(hass, action, point_in_time):
@ -69,6 +91,19 @@ def track_point_in_time(hass, action, point_in_time):
def track_point_in_utc_time(hass, action, point_in_time):
"""Add a listener that fires once after a specific point in UTC time."""
async_unsub = run_callback_threadsafe(
hass.loop, async_track_point_in_utc_time, hass, action, point_in_time
).result()
def remove():
"""Remove listener."""
run_callback_threadsafe(hass.loop, async_unsub).result()
return remove
def async_track_point_in_utc_time(hass, action, point_in_time):
"""Add a listener that fires once after a specific point in UTC time."""
# Ensure point_in_time is UTC
point_in_time = dt_util.as_utc(point_in_time)
@ -88,20 +123,14 @@ def track_point_in_utc_time(hass, action, point_in_time):
# listener gets lined up twice to be executed. This will make
# sure the second time it does nothing.
point_in_time_listener.run = True
async_remove()
async_unsub()
hass.async_add_job(action, now)
future = run_callback_threadsafe(
hass.loop, hass.bus.async_listen, EVENT_TIME_CHANGED,
point_in_time_listener)
async_remove = future.result()
async_unsub = hass.bus.async_listen(EVENT_TIME_CHANGED,
point_in_time_listener)
def remove():
"""Remove listener."""
run_callback_threadsafe(hass.loop, async_remove).result()
return remove
return async_unsub
def track_sunrise(hass, action, offset=None):
@ -118,19 +147,21 @@ def track_sunrise(hass, action, offset=None):
return next_time
@asyncio.coroutine
def sunrise_automation_listener(now):
"""Called when it's time for action."""
nonlocal remove
remove = track_point_in_utc_time(hass, sunrise_automation_listener,
next_rise())
action()
remove = async_track_point_in_utc_time(
hass, sunrise_automation_listener, next_rise())
hass.async_add_job(action)
remove = track_point_in_utc_time(hass, sunrise_automation_listener,
next_rise())
remove = run_callback_threadsafe(
hass.loop, async_track_point_in_utc_time, hass,
sunrise_automation_listener, next_rise()).result()
def remove_listener():
"""Remove sunrise listener."""
remove()
"""Remove sunset listener."""
run_callback_threadsafe(hass.loop, remove).result()
return remove_listener
@ -149,19 +180,21 @@ def track_sunset(hass, action, offset=None):
return next_time
@asyncio.coroutine
def sunset_automation_listener(now):
"""Called when it's time for action."""
nonlocal remove
remove = track_point_in_utc_time(hass, sunset_automation_listener,
next_set())
action()
remove = async_track_point_in_utc_time(
hass, sunset_automation_listener, next_set())
hass.async_add_job(action)
remove = track_point_in_utc_time(hass, sunset_automation_listener,
next_set())
remove = run_callback_threadsafe(
hass.loop, async_track_point_in_utc_time, hass,
sunset_automation_listener, next_set()).result()
def remove_listener():
"""Remove sunset listener."""
remove()
run_callback_threadsafe(hass.loop, remove).result()
return remove_listener

View File

@ -149,8 +149,8 @@ class Template(object):
global_vars = ENV.make_globals({
'closest': location_methods.closest,
'distance': location_methods.distance,
'is_state': self.hass.states.async_is_state,
'is_state_attr': self.hass.states.async_is_state_attr,
'is_state': self.hass.states.is_state,
'is_state_attr': self.hass.states.is_state_attr,
'states': AllStates(self.hass),
})

View File

@ -77,7 +77,8 @@ class TestComponentsCore(unittest.TestCase):
service_call = ha.ServiceCall('homeassistant', 'turn_on', {
'entity_id': ['light.test', 'sensor.bla', 'light.bla']
})
self.hass.services._services['homeassistant']['turn_on'](service_call)
service = self.hass.services._services['homeassistant']['turn_on']
service.func(service_call)
self.assertEqual(2, mock_call.call_count)
self.assertEqual(

View File

@ -1,7 +1,8 @@
"""The tests for the logbook component."""
# pylint: disable=protected-access,too-many-public-methods
import unittest
from datetime import timedelta
import unittest
from unittest.mock import patch
from homeassistant.components import sun
import homeassistant.core as ha
@ -18,13 +19,17 @@ from tests.common import mock_http_component, get_test_home_assistant
class TestComponentLogbook(unittest.TestCase):
"""Test the History component."""
EMPTY_CONFIG = logbook.CONFIG_SCHEMA({ha.DOMAIN: {}, logbook.DOMAIN: {}})
EMPTY_CONFIG = logbook.CONFIG_SCHEMA({logbook.DOMAIN: {}})
def setUp(self):
"""Setup things to be run when tests are started."""
self.hass = get_test_home_assistant()
mock_http_component(self.hass)
assert setup_component(self.hass, logbook.DOMAIN, self.EMPTY_CONFIG)
self.hass.config.components += ['frontend', 'recorder', 'api']
with patch('homeassistant.components.logbook.'
'register_built_in_panel'):
assert setup_component(self.hass, logbook.DOMAIN,
self.EMPTY_CONFIG)
def tearDown(self):
"""Stop everything that was started."""
@ -44,7 +49,6 @@ class TestComponentLogbook(unittest.TestCase):
logbook.ATTR_DOMAIN: 'switch',
logbook.ATTR_ENTITY_ID: 'switch.test_switch'
}, True)
self.hass.block_till_done()
self.assertEqual(1, len(calls))
last_call = calls[-1]
@ -65,7 +69,6 @@ class TestComponentLogbook(unittest.TestCase):
self.hass.bus.listen(logbook.EVENT_LOGBOOK_ENTRY, event_listener)
self.hass.services.call(logbook.DOMAIN, 'log', {}, True)
self.hass.block_till_done()
self.assertEqual(0, len(calls))

View File

@ -1,6 +1,9 @@
"""Test the entity helper."""
# pylint: disable=protected-access,too-many-public-methods
import unittest
import asyncio
from unittest.mock import MagicMock
import pytest
import homeassistant.helpers.entity as entity
from homeassistant.const import ATTR_HIDDEN
@ -8,26 +11,75 @@ from homeassistant.const import ATTR_HIDDEN
from tests.common import get_test_home_assistant
class TestHelpersEntity(unittest.TestCase):
def test_generate_entity_id_requires_hass_or_ids():
"""Ensure we require at least hass or current ids."""
fmt = 'test.{}'
with pytest.raises(ValueError):
entity.generate_entity_id(fmt, 'hello world')
def test_generate_entity_id_given_keys():
"""Test generating an entity id given current ids."""
fmt = 'test.{}'
assert entity.generate_entity_id(
fmt, 'overwrite hidden true', current_ids=[
'test.overwrite_hidden_true']) == 'test.overwrite_hidden_true_2'
assert entity.generate_entity_id(
fmt, 'overwrite hidden true', current_ids=[
'test.another_entity']) == 'test.overwrite_hidden_true'
def test_async_update_support(event_loop):
"""Test async update getting called."""
sync_update = []
async_update = []
class AsyncEntity(entity.Entity):
hass = MagicMock()
entity_id = 'sensor.test'
def update(self):
sync_update.append([1])
ent = AsyncEntity()
ent.hass.loop = event_loop
@asyncio.coroutine
def test():
yield from ent.async_update_ha_state(True)
event_loop.run_until_complete(test())
assert len(sync_update) == 1
assert len(async_update) == 0
ent.async_update = lambda: async_update.append(1)
event_loop.run_until_complete(test())
assert len(sync_update) == 1
assert len(async_update) == 1
class TestHelpersEntity(object):
"""Test homeassistant.helpers.entity module."""
def setUp(self): # pylint: disable=invalid-name
def setup_method(self, method):
"""Setup things to be run when tests are started."""
self.entity = entity.Entity()
self.entity.entity_id = 'test.overwrite_hidden_true'
self.hass = self.entity.hass = get_test_home_assistant()
self.entity.update_ha_state()
def tearDown(self): # pylint: disable=invalid-name
def teardown_method(self, method):
"""Stop everything that was started."""
self.hass.stop()
entity.set_customize({})
self.hass.stop()
def test_default_hidden_not_in_attributes(self):
"""Test that the default hidden property is set to False."""
self.assertNotIn(
ATTR_HIDDEN,
self.hass.states.get(self.entity.entity_id).attributes)
assert ATTR_HIDDEN not in self.hass.states.get(
self.entity.entity_id).attributes
def test_overwriting_hidden_property_to_true(self):
"""Test we can overwrite hidden property to True."""
@ -35,31 +87,11 @@ class TestHelpersEntity(unittest.TestCase):
self.entity.update_ha_state()
state = self.hass.states.get(self.entity.entity_id)
self.assertTrue(state.attributes.get(ATTR_HIDDEN))
def test_generate_entity_id_requires_hass_or_ids(self):
"""Ensure we require at least hass or current ids."""
fmt = 'test.{}'
with self.assertRaises(ValueError):
entity.generate_entity_id(fmt, 'hello world')
assert state.attributes.get(ATTR_HIDDEN)
def test_generate_entity_id_given_hass(self):
"""Test generating an entity id given hass object."""
fmt = 'test.{}'
self.assertEqual(
'test.overwrite_hidden_true_2',
entity.generate_entity_id(fmt, 'overwrite hidden true',
hass=self.hass))
def test_generate_entity_id_given_keys(self):
"""Test generating an entity id given current ids."""
fmt = 'test.{}'
self.assertEqual(
'test.overwrite_hidden_true_2',
entity.generate_entity_id(
fmt, 'overwrite hidden true',
current_ids=['test.overwrite_hidden_true']))
self.assertEqual(
'test.overwrite_hidden_true',
entity.generate_entity_id(fmt, 'overwrite hidden true',
current_ids=['test.another_entity']))
assert entity.generate_entity_id(
fmt, 'overwrite hidden true',
hass=self.hass) == 'test.overwrite_hidden_true_2'

View File

@ -1,6 +1,7 @@
"""Test to verify that Home Assistant core works."""
# pylint: disable=protected-access,too-many-public-methods
# pylint: disable=too-few-public-methods
import asyncio
import os
import signal
import unittest
@ -362,7 +363,6 @@ class TestServiceRegistry(unittest.TestCase):
self.hass = get_test_home_assistant()
self.services = self.hass.services
self.services.register("Test_Domain", "TEST_SERVICE", lambda x: None)
self.hass.block_till_done()
def tearDown(self): # pylint: disable=invalid-name
"""Stop down stuff we started."""
@ -387,8 +387,13 @@ class TestServiceRegistry(unittest.TestCase):
def test_call_with_blocking_done_in_time(self):
"""Test call with blocking."""
calls = []
def service_handler(call):
"""Service handler."""
calls.append(call)
self.services.register("test_domain", "register_calls",
lambda x: calls.append(1))
service_handler)
self.assertTrue(
self.services.call('test_domain', 'REGISTER_CALLS', blocking=True))
@ -404,6 +409,22 @@ class TestServiceRegistry(unittest.TestCase):
finally:
ha.SERVICE_CALL_LIMIT = prior
def test_async_service(self):
"""Test registering and calling an async service."""
calls = []
@asyncio.coroutine
def service_handler(call):
"""Service handler coroutine."""
calls.append(call)
self.services.register('test_domain', 'register_calls',
service_handler)
self.assertTrue(
self.services.call('test_domain', 'REGISTER_CALLS', blocking=True))
self.hass.block_till_done()
self.assertEqual(1, len(calls))
class TestConfig(unittest.TestCase):
"""Test configuration methods."""