1
mirror of https://github.com/home-assistant/core synced 2024-10-04 07:58:43 +02:00

Address asyncio comments (#3663)

* Template platforms: create_task instead of yield from

* Automation: less yielding, more create_tasking

* Helpers.script: less yielding, more create_tasking

* Deflake logbook test

* Deflake automation reload config test

* MQTT: Use async_add_job and threaded_listener_factory

* Deflake other logbook test

* lint

* Add test for automation trigger service

* MQTT client can be called from within async
This commit is contained in:
Paulus Schoutsen 2016-10-03 22:39:27 -07:00 committed by GitHub
parent f2a12b7ac2
commit d58548dd1c
10 changed files with 123 additions and 76 deletions

View File

@ -154,15 +154,24 @@ def setup(hass, config):
def trigger_service_handler(service_call):
"""Handle automation triggers."""
for entity in component.extract_from_service(service_call):
yield from entity.async_trigger(
service_call.data.get(ATTR_VARIABLES))
hass.loop.create_task(entity.async_trigger(
service_call.data.get(ATTR_VARIABLES), True))
@asyncio.coroutine
def service_handler(service_call):
"""Handle automation service calls."""
def turn_onoff_service_handler(service_call):
"""Handle automation turn on/off service calls."""
method = 'async_{}'.format(service_call.service)
for entity in component.extract_from_service(service_call):
yield from getattr(entity, method)()
hass.loop.create_task(getattr(entity, method)())
@asyncio.coroutine
def toggle_service_handler(service_call):
"""Handle automation toggle service calls."""
for entity in component.extract_from_service(service_call):
if entity.is_on:
hass.loop.create_task(entity.async_turn_off())
else:
hass.loop.create_task(entity.async_turn_on())
@asyncio.coroutine
def reload_service_handler(service_call):
@ -171,7 +180,7 @@ def setup(hass, config):
None, component.prepare_reload)
if conf is None:
return
yield from _async_process_config(hass, conf, component)
hass.loop.create_task(_async_process_config(hass, conf, component))
hass.services.register(DOMAIN, SERVICE_TRIGGER, trigger_service_handler,
descriptions.get(SERVICE_TRIGGER),
@ -181,8 +190,12 @@ def setup(hass, config):
descriptions.get(SERVICE_RELOAD),
schema=RELOAD_SERVICE_SCHEMA)
for service in (SERVICE_TURN_ON, SERVICE_TURN_OFF, SERVICE_TOGGLE):
hass.services.register(DOMAIN, service, service_handler,
hass.services.register(DOMAIN, SERVICE_TOGGLE, toggle_service_handler,
descriptions.get(SERVICE_TOGGLE),
schema=SERVICE_SCHEMA)
for service in (SERVICE_TURN_ON, SERVICE_TURN_OFF):
hass.services.register(DOMAIN, service, turn_onoff_service_handler,
descriptions.get(service),
schema=SERVICE_SCHEMA)
@ -236,8 +249,11 @@ class AutomationEntity(ToggleEntity):
@asyncio.coroutine
def async_turn_on(self, **kwargs) -> None:
"""Turn the entity on and update the state."""
if self._enabled:
return
yield from self.async_enable()
yield from self.async_update_ha_state()
self.hass.loop.create_task(self.async_update_ha_state())
@asyncio.coroutine
def async_turn_off(self, **kwargs) -> None:
@ -248,23 +264,18 @@ class AutomationEntity(ToggleEntity):
self._async_detach_triggers()
self._async_detach_triggers = None
self._enabled = False
yield from self.async_update_ha_state()
self.hass.loop.create_task(self.async_update_ha_state())
@asyncio.coroutine
def async_toggle(self):
"""Toggle the state of the entity."""
if self._enabled:
yield from self.async_turn_off()
else:
yield from self.async_turn_on()
def async_trigger(self, variables, skip_condition=False):
"""Trigger automation.
@asyncio.coroutine
def async_trigger(self, variables):
"""Trigger automation."""
if self._cond_func(variables):
This method is a coroutine.
"""
if skip_condition or self._cond_func(variables):
yield from self._async_action(variables)
self._last_triggered = utcnow()
yield from self.async_update_ha_state()
self.hass.loop.create_task(self.async_update_ha_state())
def remove(self):
"""Remove automation from HASS."""
@ -274,7 +285,10 @@ class AutomationEntity(ToggleEntity):
@asyncio.coroutine
def async_enable(self):
"""Enable this automation entity."""
"""Enable this automation entity.
This method is a coroutine.
"""
if self._enabled:
return
@ -285,8 +299,12 @@ class AutomationEntity(ToggleEntity):
@asyncio.coroutine
def _async_process_config(hass, config, component):
"""Process config and add automations."""
"""Process config and add automations.
This method is a coroutine.
"""
entities = []
tasks = []
for config_key in extract_domain_configs(config, DOMAIN):
conf = config[config_key]
@ -315,9 +333,10 @@ def _async_process_config(hass, config, component):
config_block.get(CONF_TRIGGER, []), name)
entity = AutomationEntity(name, async_attach_triggers, cond_func,
action, hidden)
yield from entity.async_enable()
tasks.append(hass.loop.create_task(entity.async_enable()))
entities.append(entity)
yield from asyncio.gather(*tasks, loop=hass.loop)
yield from hass.loop.run_in_executor(
None, component.add_entities, entities)
@ -333,7 +352,7 @@ def _async_get_action(hass, config, name):
"""Action to be executed."""
_LOGGER.info('Executing %s', name)
logbook.async_log_entry(hass, name, 'has been triggered', DOMAIN)
yield from script_obj.async_run(variables)
hass.loop.create_task(script_obj.async_run(variables))
return action
@ -359,7 +378,10 @@ def _async_process_if(hass, config, p_config):
@asyncio.coroutine
def _async_process_trigger(hass, config, trigger_configs, name, action):
"""Setup the triggers."""
"""Setup the triggers.
This method is a coroutine.
"""
removes = []
for conf in trigger_configs:

View File

@ -85,7 +85,7 @@ class BinarySensorTemplate(BinarySensorDevice):
@asyncio.coroutine
def template_bsensor_state_listener(entity, old_state, new_state):
"""Called when the target device changes state."""
yield from self.async_update_ha_state(True)
hass.loop.create_task(self.async_update_ha_state(True))
track_state_change(hass, entity_ids, template_bsensor_state_listener)

View File

@ -12,16 +12,14 @@ import time
import voluptuous as vol
from homeassistant.core import JobPriority
from homeassistant.bootstrap import prepare_setup_platform
from homeassistant.config import load_yaml_config_file
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import template
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers import template, config_validation as cv
from homeassistant.helpers.event import threaded_listener_factory
from homeassistant.const import (
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP,
CONF_PLATFORM, CONF_SCAN_INTERVAL, CONF_VALUE_TEMPLATE)
from homeassistant.util.async import run_callback_threadsafe
_LOGGER = logging.getLogger(__name__)
@ -165,18 +163,6 @@ def publish_template(hass, topic, payload_template, qos=None, retain=None):
hass.services.call(DOMAIN, SERVICE_PUBLISH, data)
def subscribe(hass, topic, callback, qos=DEFAULT_QOS):
"""Subscribe to an MQTT topic."""
async_remove = run_callback_threadsafe(
hass.loop, async_subscribe, hass, topic, callback, qos).result()
def remove_mqtt():
"""Remove MQTT subscription."""
run_callback_threadsafe(hass.loop, async_remove).result()
return remove_mqtt
def async_subscribe(hass, topic, callback, qos=DEFAULT_QOS):
"""Subscribe to an MQTT topic."""
@asyncio.coroutine
@ -185,14 +171,8 @@ def async_subscribe(hass, topic, callback, qos=DEFAULT_QOS):
if not _match_topic(topic, event.data[ATTR_TOPIC]):
return
if asyncio.iscoroutinefunction(callback):
yield from callback(
event.data[ATTR_TOPIC], event.data[ATTR_PAYLOAD],
event.data[ATTR_QOS])
else:
hass.add_job(callback, event.data[ATTR_TOPIC],
event.data[ATTR_PAYLOAD], event.data[ATTR_QOS],
priority=JobPriority.EVENT_CALLBACK)
hass.async_add_job(callback, event.data[ATTR_TOPIC],
event.data[ATTR_PAYLOAD], event.data[ATTR_QOS])
async_remove = hass.bus.async_listen(EVENT_MQTT_MESSAGE_RECEIVED,
mqtt_topic_subscriber)
@ -203,6 +183,10 @@ def async_subscribe(hass, topic, callback, qos=DEFAULT_QOS):
return async_remove
# pylint: disable=invalid-name
subscribe = threaded_listener_factory(async_subscribe)
def _setup_server(hass, config):
"""Try to start embedded MQTT broker."""
conf = config.get(DOMAIN, {})

View File

@ -124,7 +124,7 @@ class ScriptEntity(ToggleEntity):
def __init__(self, hass, object_id, name, sequence):
"""Initialize the script."""
self.entity_id = ENTITY_ID_FORMAT.format(object_id)
self.script = Script(hass, sequence, name, self.update_ha_state)
self.script = Script(hass, sequence, name, self.async_update_ha_state)
@property
def should_poll(self):

View File

@ -82,7 +82,7 @@ class SensorTemplate(Entity):
@asyncio.coroutine
def template_sensor_state_listener(entity, old_state, new_state):
"""Called when the target device changes state."""
yield from self.async_update_ha_state(True)
hass.loop.create_task(self.async_update_ha_state(True))
track_state_change(hass, entity_ids, template_sensor_state_listener)

View File

@ -91,7 +91,7 @@ class SwitchTemplate(SwitchDevice):
@asyncio.coroutine
def template_switch_state_listener(entity, old_state, new_state):
"""Called when the target device changes state."""
yield from self.async_update_ha_state(True)
hass.loop.create_task(self.async_update_ha_state(True))
track_state_change(hass, entity_ids, template_switch_state_listener)

View File

@ -9,11 +9,11 @@ from ..const import (
from ..util import dt as dt_util
from ..util.async import run_callback_threadsafe
# PyLint does not like the use of _threaded_factory
# PyLint does not like the use of threaded_listener_factory
# pylint: disable=invalid-name
def _threaded_factory(async_factory):
def threaded_listener_factory(async_factory):
"""Convert an async event helper to a threaded one."""
@ft.wraps(async_factory)
def factory(*args, **kwargs):
@ -83,7 +83,7 @@ def async_track_state_change(hass, entity_ids, action, from_state=None,
return hass.bus.async_listen(EVENT_STATE_CHANGED, state_change_listener)
track_state_change = _threaded_factory(async_track_state_change)
track_state_change = threaded_listener_factory(async_track_state_change)
def async_track_point_in_time(hass, action, point_in_time):
@ -100,7 +100,7 @@ def async_track_point_in_time(hass, action, point_in_time):
utc_point_in_time)
track_point_in_time = _threaded_factory(async_track_point_in_time)
track_point_in_time = threaded_listener_factory(async_track_point_in_time)
def async_track_point_in_utc_time(hass, action, point_in_time):
@ -133,7 +133,8 @@ def async_track_point_in_utc_time(hass, action, point_in_time):
return async_unsub
track_point_in_utc_time = _threaded_factory(async_track_point_in_utc_time)
track_point_in_utc_time = threaded_listener_factory(
async_track_point_in_utc_time)
def async_track_sunrise(hass, action, offset=None):
@ -169,7 +170,7 @@ def async_track_sunrise(hass, action, offset=None):
return remove_listener
track_sunrise = _threaded_factory(async_track_sunrise)
track_sunrise = threaded_listener_factory(async_track_sunrise)
def async_track_sunset(hass, action, offset=None):
@ -205,7 +206,7 @@ def async_track_sunset(hass, action, offset=None):
return remove_listener
track_sunset = _threaded_factory(async_track_sunset)
track_sunset = threaded_listener_factory(async_track_sunset)
# pylint: disable=too-many-arguments
@ -251,7 +252,7 @@ def async_track_utc_time_change(hass, action, year=None, month=None, day=None,
pattern_time_change_listener)
track_utc_time_change = _threaded_factory(async_track_utc_time_change)
track_utc_time_change = threaded_listener_factory(async_track_utc_time_change)
# pylint: disable=too-many-arguments
@ -262,7 +263,7 @@ def async_track_time_change(hass, action, year=None, month=None, day=None,
minute, second, local=True)
track_time_change = _threaded_factory(async_track_time_change)
track_time_change = threaded_listener_factory(async_track_time_change)
def _process_state_match(parameter):

View File

@ -66,7 +66,7 @@ class Script():
def async_run(self, variables: Optional[Sequence]=None) -> None:
"""Run script.
Returns a coroutine.
This method is a coroutine.
"""
if self._cur == -1:
self._log('Running script')
@ -85,7 +85,7 @@ class Script():
def script_delay(now):
"""Called after delay is done."""
self._async_unsub_delay_listener = None
yield from self.async_run(variables)
self.hass.loop.create_task(self.async_run(variables))
delay = action[CONF_DELAY]
@ -100,7 +100,8 @@ class Script():
self.hass, script_delay,
date_util.utcnow() + delay)
self._cur = cur + 1
self._trigger_change_listener()
if self._change_listener:
self.hass.async_add_job(self._change_listener)
return
elif CONF_CONDITION in action:
@ -115,7 +116,8 @@ class Script():
self._cur = -1
self.last_action = None
self._trigger_change_listener()
if self._change_listener:
self.hass.async_add_job(self._change_listener)
def stop(self) -> None:
"""Stop running script."""
@ -128,11 +130,15 @@ class Script():
self._cur = -1
self._async_remove_listener()
self._trigger_change_listener()
if self._change_listener:
self.hass.async_add_job(self._change_listener)
@asyncio.coroutine
def _async_call_service(self, action, variables):
"""Call the service specified in the action."""
"""Call the service specified in the action.
This method is a coroutine.
"""
self.last_action = action.get(CONF_ALIAS, 'call service')
self._log("Executing step %s" % self.last_action)
yield from service.async_call_from_config(
@ -165,10 +171,3 @@ class Script():
msg = "Script {}: {}".format(self.name, msg)
_LOGGER.info(msg)
def _trigger_change_listener(self):
"""Trigger the change listener."""
if not self._change_listener:
return
self.hass.async_add_job(self._change_listener)

View File

@ -144,6 +144,35 @@ class TestAutomation(unittest.TestCase):
self.hass.block_till_done()
self.assertEqual(2, len(self.calls))
def test_trigger_service_ignoring_condition(self):
"""Test triggers."""
assert setup_component(self.hass, automation.DOMAIN, {
automation.DOMAIN: {
'trigger': [
{
'platform': 'event',
'event_type': 'test_event',
},
],
'condition': {
'condition': 'state',
'entity_id': 'non.existing',
'state': 'beer',
},
'action': {
'service': 'test.automation',
}
}
})
self.hass.bus.fire('test_event')
self.hass.block_till_done()
assert len(self.calls) == 0
self.hass.services.call('automation', 'trigger', blocking=True)
self.hass.block_till_done()
assert len(self.calls) == 1
def test_two_conditions_with_and(self):
"""Test two and conditions."""
entity_id = 'test.entity'
@ -348,6 +377,8 @@ class TestAutomation(unittest.TestCase):
automation.reload(self.hass)
self.hass.block_till_done()
# De-flake ?!
self.hass.block_till_done()
assert self.hass.states.get('automation.hello') is None
assert self.hass.states.get('automation.bye') is not None

View File

@ -50,6 +50,11 @@ class TestComponentLogbook(unittest.TestCase):
logbook.ATTR_ENTITY_ID: 'switch.test_switch'
}, True)
# Logbook entry service call results in firing an event.
# Our service call will unblock when the event listeners have been
# scheduled. This means that they may not have been processed yet.
self.hass.block_till_done()
self.assertEqual(1, len(calls))
last_call = calls[-1]
@ -70,6 +75,11 @@ class TestComponentLogbook(unittest.TestCase):
self.hass.bus.listen(logbook.EVENT_LOGBOOK_ENTRY, event_listener)
self.hass.services.call(logbook.DOMAIN, 'log', {}, True)
# Logbook entry service call results in firing an event.
# Our service call will unblock when the event listeners have been
# scheduled. This means that they may not have been processed yet.
self.hass.block_till_done()
self.assertEqual(0, len(calls))
def test_humanify_filter_sensor(self):