1
mirror of https://github.com/home-assistant/core synced 2024-07-30 21:18:57 +02:00

Fix service helper not handling sync methods (#31254)

* Fix service helper not handling sync methods

* Add legacy support for returning coroutine objects

* Fix tests

* Fix tests

* Convert demo cover to async
This commit is contained in:
Paulus Schoutsen 2020-01-29 16:27:25 -08:00 committed by GitHub
parent 111fc1fa8e
commit 01dad31adc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 73 additions and 46 deletions

View File

@ -6,7 +6,8 @@ from homeassistant.components.cover import (
SUPPORT_OPEN,
CoverDevice,
)
from homeassistant.helpers.event import track_utc_time_change
from homeassistant.core import callback
from homeassistant.helpers.event import async_track_utc_time_change
from . import DOMAIN
@ -131,21 +132,21 @@ class DemoCover(CoverDevice):
return self._supported_features
return super().supported_features
def close_cover(self, **kwargs):
async def async_close_cover(self, **kwargs):
"""Close the cover."""
if self._position == 0:
return
if self._position is None:
self._closed = True
self.schedule_update_ha_state()
self.async_write_ha_state()
return
self._is_closing = True
self._listen_cover()
self._requested_closing = True
self.schedule_update_ha_state()
self.async_write_ha_state()
def close_cover_tilt(self, **kwargs):
async def async_close_cover_tilt(self, **kwargs):
"""Close the cover tilt."""
if self._tilt_position in (0, None):
return
@ -153,21 +154,21 @@ class DemoCover(CoverDevice):
self._listen_cover_tilt()
self._requested_closing_tilt = True
def open_cover(self, **kwargs):
async def async_open_cover(self, **kwargs):
"""Open the cover."""
if self._position == 100:
return
if self._position is None:
self._closed = False
self.schedule_update_ha_state()
self.async_write_ha_state()
return
self._is_opening = True
self._listen_cover()
self._requested_closing = False
self.schedule_update_ha_state()
self.async_write_ha_state()
def open_cover_tilt(self, **kwargs):
async def async_open_cover_tilt(self, **kwargs):
"""Open the cover tilt."""
if self._tilt_position in (100, None):
return
@ -175,7 +176,7 @@ class DemoCover(CoverDevice):
self._listen_cover_tilt()
self._requested_closing_tilt = False
def set_cover_position(self, **kwargs):
async def async_set_cover_position(self, **kwargs):
"""Move the cover to a specific position."""
position = kwargs.get(ATTR_POSITION)
self._set_position = round(position, -1)
@ -185,7 +186,7 @@ class DemoCover(CoverDevice):
self._listen_cover()
self._requested_closing = position < self._position
def set_cover_tilt_position(self, **kwargs):
async def async_set_cover_tilt_position(self, **kwargs):
"""Move the cover til to a specific position."""
tilt_position = kwargs.get(ATTR_TILT_POSITION)
self._set_tilt_position = round(tilt_position, -1)
@ -195,7 +196,7 @@ class DemoCover(CoverDevice):
self._listen_cover_tilt()
self._requested_closing_tilt = tilt_position < self._tilt_position
def stop_cover(self, **kwargs):
async def async_stop_cover(self, **kwargs):
"""Stop the cover."""
self._is_closing = False
self._is_opening = False
@ -206,7 +207,7 @@ class DemoCover(CoverDevice):
self._unsub_listener_cover = None
self._set_position = None
def stop_cover_tilt(self, **kwargs):
async def async_stop_cover_tilt(self, **kwargs):
"""Stop the cover tilt."""
if self._tilt_position is None:
return
@ -216,14 +217,15 @@ class DemoCover(CoverDevice):
self._unsub_listener_cover_tilt = None
self._set_tilt_position = None
@callback
def _listen_cover(self):
"""Listen for changes in cover."""
if self._unsub_listener_cover is None:
self._unsub_listener_cover = track_utc_time_change(
self._unsub_listener_cover = async_track_utc_time_change(
self.hass, self._time_changed_cover
)
def _time_changed_cover(self, now):
async def _time_changed_cover(self, now):
"""Track time changes."""
if self._requested_closing:
self._position -= 10
@ -231,20 +233,20 @@ class DemoCover(CoverDevice):
self._position += 10
if self._position in (100, 0, self._set_position):
self.stop_cover()
await self.async_stop_cover()
self._closed = self.current_cover_position <= 0
self.async_write_ha_state()
self.schedule_update_ha_state()
@callback
def _listen_cover_tilt(self):
"""Listen for changes in cover tilt."""
if self._unsub_listener_cover_tilt is None:
self._unsub_listener_cover_tilt = track_utc_time_change(
self._unsub_listener_cover_tilt = async_track_utc_time_change(
self.hass, self._time_changed_cover_tilt
)
def _time_changed_cover_tilt(self, now):
async def _time_changed_cover_tilt(self, now):
"""Track time changes."""
if self._requested_closing_tilt:
self._tilt_position -= 10
@ -252,6 +254,6 @@ class DemoCover(CoverDevice):
self._tilt_position += 10
if self._tilt_position in (100, 0, self._set_tilt_position):
self.stop_cover_tilt()
await self.async_stop_cover_tilt()
self.schedule_update_ha_state()
self.async_write_ha_state()

View File

@ -1,6 +1,6 @@
"""Service calling related helpers."""
import asyncio
from functools import wraps
from functools import partial, wraps
import logging
from typing import Callable
@ -339,7 +339,7 @@ async def entity_service_call(hass, platforms, func, call, required_features=Non
tasks = [
_handle_service_platform_call(
func, data, entities, call.context, required_features
hass, func, data, entities, call.context, required_features
)
for platform, entities in zip(platforms, platforms_entities)
]
@ -352,7 +352,7 @@ async def entity_service_call(hass, platforms, func, call, required_features=Non
async def _handle_service_platform_call(
func, data, entities, context, required_features
hass, func, data, entities, context, required_features
):
"""Handle a function call."""
tasks = []
@ -370,9 +370,17 @@ async def _handle_service_platform_call(
entity.async_set_context(context)
if isinstance(func, str):
await getattr(entity, func)(**data)
result = await hass.async_add_job(partial(getattr(entity, func), **data))
else:
await func(entity, data)
result = await hass.async_add_job(func, entity, data)
if asyncio.iscoroutine(result):
_LOGGER.error(
"Service %s for %s incorrectly returns a coroutine object. Await result instead in service handler. Report bug to component author.",
func,
entity.entity_id,
)
await result
if entity.should_poll:
tasks.append(entity.async_update_ha_state(True))

View File

@ -4,7 +4,6 @@ Test setup of RFLink lights component/platform. State tracking and
control of RFLink switch devices.
"""
from homeassistant.components.light import ATTR_BRIGHTNESS
from homeassistant.components.rflink import EVENT_BUTTON_PRESSED
from homeassistant.const import (
@ -267,15 +266,11 @@ async def test_signal_repetitions_alternation(hass, monkeypatch):
# setup mocking rflink module
_, _, protocol, _ = await mock_rflink(hass, config, DOMAIN, monkeypatch)
hass.async_create_task(
hass.services.async_call(
DOMAIN, SERVICE_TURN_OFF, {ATTR_ENTITY_ID: DOMAIN + ".test"}
)
await hass.services.async_call(
DOMAIN, SERVICE_TURN_OFF, {ATTR_ENTITY_ID: DOMAIN + ".test"}
)
hass.async_create_task(
hass.services.async_call(
DOMAIN, SERVICE_TURN_OFF, {ATTR_ENTITY_ID: DOMAIN + ".test1"}
)
await hass.services.async_call(
DOMAIN, SERVICE_TURN_OFF, {ATTR_ENTITY_ID: DOMAIN + ".test1"}
)
await hass.async_block_till_done()
@ -299,10 +294,8 @@ async def test_signal_repetitions_cancelling(hass, monkeypatch):
# setup mocking rflink module
_, _, protocol, _ = await mock_rflink(hass, config, DOMAIN, monkeypatch)
hass.async_create_task(
hass.services.async_call(
DOMAIN, SERVICE_TURN_OFF, {ATTR_ENTITY_ID: DOMAIN + ".test"}
)
await hass.services.async_call(
DOMAIN, SERVICE_TURN_OFF, {ATTR_ENTITY_ID: DOMAIN + ".test"}
)
hass.async_create_task(

View File

@ -306,6 +306,30 @@ async def test_call_with_required_features(hass, mock_entities):
assert test_service_mock.call_count == 1
async def test_call_with_sync_func(hass, mock_entities):
"""Test invoking sync service calls."""
test_service_mock = Mock()
await service.entity_service_call(
hass,
[Mock(entities=mock_entities)],
test_service_mock,
ha.ServiceCall("test_domain", "test_service", {"entity_id": "light.kitchen"}),
)
assert test_service_mock.call_count == 1
async def test_call_with_sync_attr(hass, mock_entities):
"""Test invoking sync service calls."""
mock_entities["light.kitchen"].sync_method = Mock()
await service.entity_service_call(
hass,
[Mock(entities=mock_entities)],
"sync_method",
ha.ServiceCall("test_domain", "test_service", {"entity_id": "light.kitchen"}),
)
assert mock_entities["light.kitchen"].sync_method.call_count == 1
async def test_call_context_user_not_exist(hass):
"""Check we don't allow deleted users to do things."""
with pytest.raises(exceptions.UnknownUser) as err:
@ -348,7 +372,7 @@ async def test_call_context_target_all(hass, mock_service_platform_call, mock_en
)
assert len(mock_service_platform_call.mock_calls) == 1
entities = mock_service_platform_call.mock_calls[0][1][2]
entities = mock_service_platform_call.mock_calls[0][1][3]
assert entities == [mock_entities["light.kitchen"]]
@ -379,7 +403,7 @@ async def test_call_context_target_specific(
)
assert len(mock_service_platform_call.mock_calls) == 1
entities = mock_service_platform_call.mock_calls[0][1][2]
entities = mock_service_platform_call.mock_calls[0][1][3]
assert entities == [mock_entities["light.kitchen"]]
@ -422,7 +446,7 @@ async def test_call_no_context_target_all(
)
assert len(mock_service_platform_call.mock_calls) == 1
entities = mock_service_platform_call.mock_calls[0][1][2]
entities = mock_service_platform_call.mock_calls[0][1][3]
assert entities == list(mock_entities.values())
@ -442,7 +466,7 @@ async def test_call_no_context_target_specific(
)
assert len(mock_service_platform_call.mock_calls) == 1
entities = mock_service_platform_call.mock_calls[0][1][2]
entities = mock_service_platform_call.mock_calls[0][1][3]
assert entities == [mock_entities["light.kitchen"]]
@ -458,7 +482,7 @@ async def test_call_with_match_all(
)
assert len(mock_service_platform_call.mock_calls) == 1
entities = mock_service_platform_call.mock_calls[0][1][2]
entities = mock_service_platform_call.mock_calls[0][1][3]
assert entities == [
mock_entities["light.kitchen"],
mock_entities["light.living_room"],
@ -480,7 +504,7 @@ async def test_call_with_omit_entity_id(
)
assert len(mock_service_platform_call.mock_calls) == 1
entities = mock_service_platform_call.mock_calls[0][1][2]
entities = mock_service_platform_call.mock_calls[0][1][3]
assert entities == []