mirror of
https://github.com/home-assistant/core
synced 2024-07-30 21:18:57 +02:00
Fix service helper not handling sync methods (#31254)
* Fix service helper not handling sync methods * Add legacy support for returning coroutine objects * Fix tests * Fix tests * Convert demo cover to async
This commit is contained in:
parent
111fc1fa8e
commit
01dad31adc
@ -6,7 +6,8 @@ from homeassistant.components.cover import (
|
||||
SUPPORT_OPEN,
|
||||
CoverDevice,
|
||||
)
|
||||
from homeassistant.helpers.event import track_utc_time_change
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.helpers.event import async_track_utc_time_change
|
||||
|
||||
from . import DOMAIN
|
||||
|
||||
@ -131,21 +132,21 @@ class DemoCover(CoverDevice):
|
||||
return self._supported_features
|
||||
return super().supported_features
|
||||
|
||||
def close_cover(self, **kwargs):
|
||||
async def async_close_cover(self, **kwargs):
|
||||
"""Close the cover."""
|
||||
if self._position == 0:
|
||||
return
|
||||
if self._position is None:
|
||||
self._closed = True
|
||||
self.schedule_update_ha_state()
|
||||
self.async_write_ha_state()
|
||||
return
|
||||
|
||||
self._is_closing = True
|
||||
self._listen_cover()
|
||||
self._requested_closing = True
|
||||
self.schedule_update_ha_state()
|
||||
self.async_write_ha_state()
|
||||
|
||||
def close_cover_tilt(self, **kwargs):
|
||||
async def async_close_cover_tilt(self, **kwargs):
|
||||
"""Close the cover tilt."""
|
||||
if self._tilt_position in (0, None):
|
||||
return
|
||||
@ -153,21 +154,21 @@ class DemoCover(CoverDevice):
|
||||
self._listen_cover_tilt()
|
||||
self._requested_closing_tilt = True
|
||||
|
||||
def open_cover(self, **kwargs):
|
||||
async def async_open_cover(self, **kwargs):
|
||||
"""Open the cover."""
|
||||
if self._position == 100:
|
||||
return
|
||||
if self._position is None:
|
||||
self._closed = False
|
||||
self.schedule_update_ha_state()
|
||||
self.async_write_ha_state()
|
||||
return
|
||||
|
||||
self._is_opening = True
|
||||
self._listen_cover()
|
||||
self._requested_closing = False
|
||||
self.schedule_update_ha_state()
|
||||
self.async_write_ha_state()
|
||||
|
||||
def open_cover_tilt(self, **kwargs):
|
||||
async def async_open_cover_tilt(self, **kwargs):
|
||||
"""Open the cover tilt."""
|
||||
if self._tilt_position in (100, None):
|
||||
return
|
||||
@ -175,7 +176,7 @@ class DemoCover(CoverDevice):
|
||||
self._listen_cover_tilt()
|
||||
self._requested_closing_tilt = False
|
||||
|
||||
def set_cover_position(self, **kwargs):
|
||||
async def async_set_cover_position(self, **kwargs):
|
||||
"""Move the cover to a specific position."""
|
||||
position = kwargs.get(ATTR_POSITION)
|
||||
self._set_position = round(position, -1)
|
||||
@ -185,7 +186,7 @@ class DemoCover(CoverDevice):
|
||||
self._listen_cover()
|
||||
self._requested_closing = position < self._position
|
||||
|
||||
def set_cover_tilt_position(self, **kwargs):
|
||||
async def async_set_cover_tilt_position(self, **kwargs):
|
||||
"""Move the cover til to a specific position."""
|
||||
tilt_position = kwargs.get(ATTR_TILT_POSITION)
|
||||
self._set_tilt_position = round(tilt_position, -1)
|
||||
@ -195,7 +196,7 @@ class DemoCover(CoverDevice):
|
||||
self._listen_cover_tilt()
|
||||
self._requested_closing_tilt = tilt_position < self._tilt_position
|
||||
|
||||
def stop_cover(self, **kwargs):
|
||||
async def async_stop_cover(self, **kwargs):
|
||||
"""Stop the cover."""
|
||||
self._is_closing = False
|
||||
self._is_opening = False
|
||||
@ -206,7 +207,7 @@ class DemoCover(CoverDevice):
|
||||
self._unsub_listener_cover = None
|
||||
self._set_position = None
|
||||
|
||||
def stop_cover_tilt(self, **kwargs):
|
||||
async def async_stop_cover_tilt(self, **kwargs):
|
||||
"""Stop the cover tilt."""
|
||||
if self._tilt_position is None:
|
||||
return
|
||||
@ -216,14 +217,15 @@ class DemoCover(CoverDevice):
|
||||
self._unsub_listener_cover_tilt = None
|
||||
self._set_tilt_position = None
|
||||
|
||||
@callback
|
||||
def _listen_cover(self):
|
||||
"""Listen for changes in cover."""
|
||||
if self._unsub_listener_cover is None:
|
||||
self._unsub_listener_cover = track_utc_time_change(
|
||||
self._unsub_listener_cover = async_track_utc_time_change(
|
||||
self.hass, self._time_changed_cover
|
||||
)
|
||||
|
||||
def _time_changed_cover(self, now):
|
||||
async def _time_changed_cover(self, now):
|
||||
"""Track time changes."""
|
||||
if self._requested_closing:
|
||||
self._position -= 10
|
||||
@ -231,20 +233,20 @@ class DemoCover(CoverDevice):
|
||||
self._position += 10
|
||||
|
||||
if self._position in (100, 0, self._set_position):
|
||||
self.stop_cover()
|
||||
await self.async_stop_cover()
|
||||
|
||||
self._closed = self.current_cover_position <= 0
|
||||
self.async_write_ha_state()
|
||||
|
||||
self.schedule_update_ha_state()
|
||||
|
||||
@callback
|
||||
def _listen_cover_tilt(self):
|
||||
"""Listen for changes in cover tilt."""
|
||||
if self._unsub_listener_cover_tilt is None:
|
||||
self._unsub_listener_cover_tilt = track_utc_time_change(
|
||||
self._unsub_listener_cover_tilt = async_track_utc_time_change(
|
||||
self.hass, self._time_changed_cover_tilt
|
||||
)
|
||||
|
||||
def _time_changed_cover_tilt(self, now):
|
||||
async def _time_changed_cover_tilt(self, now):
|
||||
"""Track time changes."""
|
||||
if self._requested_closing_tilt:
|
||||
self._tilt_position -= 10
|
||||
@ -252,6 +254,6 @@ class DemoCover(CoverDevice):
|
||||
self._tilt_position += 10
|
||||
|
||||
if self._tilt_position in (100, 0, self._set_tilt_position):
|
||||
self.stop_cover_tilt()
|
||||
await self.async_stop_cover_tilt()
|
||||
|
||||
self.schedule_update_ha_state()
|
||||
self.async_write_ha_state()
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""Service calling related helpers."""
|
||||
import asyncio
|
||||
from functools import wraps
|
||||
from functools import partial, wraps
|
||||
import logging
|
||||
from typing import Callable
|
||||
|
||||
@ -339,7 +339,7 @@ async def entity_service_call(hass, platforms, func, call, required_features=Non
|
||||
|
||||
tasks = [
|
||||
_handle_service_platform_call(
|
||||
func, data, entities, call.context, required_features
|
||||
hass, func, data, entities, call.context, required_features
|
||||
)
|
||||
for platform, entities in zip(platforms, platforms_entities)
|
||||
]
|
||||
@ -352,7 +352,7 @@ async def entity_service_call(hass, platforms, func, call, required_features=Non
|
||||
|
||||
|
||||
async def _handle_service_platform_call(
|
||||
func, data, entities, context, required_features
|
||||
hass, func, data, entities, context, required_features
|
||||
):
|
||||
"""Handle a function call."""
|
||||
tasks = []
|
||||
@ -370,9 +370,17 @@ async def _handle_service_platform_call(
|
||||
entity.async_set_context(context)
|
||||
|
||||
if isinstance(func, str):
|
||||
await getattr(entity, func)(**data)
|
||||
result = await hass.async_add_job(partial(getattr(entity, func), **data))
|
||||
else:
|
||||
await func(entity, data)
|
||||
result = await hass.async_add_job(func, entity, data)
|
||||
|
||||
if asyncio.iscoroutine(result):
|
||||
_LOGGER.error(
|
||||
"Service %s for %s incorrectly returns a coroutine object. Await result instead in service handler. Report bug to component author.",
|
||||
func,
|
||||
entity.entity_id,
|
||||
)
|
||||
await result
|
||||
|
||||
if entity.should_poll:
|
||||
tasks.append(entity.async_update_ha_state(True))
|
||||
|
@ -4,7 +4,6 @@ Test setup of RFLink lights component/platform. State tracking and
|
||||
control of RFLink switch devices.
|
||||
|
||||
"""
|
||||
|
||||
from homeassistant.components.light import ATTR_BRIGHTNESS
|
||||
from homeassistant.components.rflink import EVENT_BUTTON_PRESSED
|
||||
from homeassistant.const import (
|
||||
@ -267,15 +266,11 @@ async def test_signal_repetitions_alternation(hass, monkeypatch):
|
||||
# setup mocking rflink module
|
||||
_, _, protocol, _ = await mock_rflink(hass, config, DOMAIN, monkeypatch)
|
||||
|
||||
hass.async_create_task(
|
||||
hass.services.async_call(
|
||||
DOMAIN, SERVICE_TURN_OFF, {ATTR_ENTITY_ID: DOMAIN + ".test"}
|
||||
)
|
||||
await hass.services.async_call(
|
||||
DOMAIN, SERVICE_TURN_OFF, {ATTR_ENTITY_ID: DOMAIN + ".test"}
|
||||
)
|
||||
hass.async_create_task(
|
||||
hass.services.async_call(
|
||||
DOMAIN, SERVICE_TURN_OFF, {ATTR_ENTITY_ID: DOMAIN + ".test1"}
|
||||
)
|
||||
await hass.services.async_call(
|
||||
DOMAIN, SERVICE_TURN_OFF, {ATTR_ENTITY_ID: DOMAIN + ".test1"}
|
||||
)
|
||||
|
||||
await hass.async_block_till_done()
|
||||
@ -299,10 +294,8 @@ async def test_signal_repetitions_cancelling(hass, monkeypatch):
|
||||
# setup mocking rflink module
|
||||
_, _, protocol, _ = await mock_rflink(hass, config, DOMAIN, monkeypatch)
|
||||
|
||||
hass.async_create_task(
|
||||
hass.services.async_call(
|
||||
DOMAIN, SERVICE_TURN_OFF, {ATTR_ENTITY_ID: DOMAIN + ".test"}
|
||||
)
|
||||
await hass.services.async_call(
|
||||
DOMAIN, SERVICE_TURN_OFF, {ATTR_ENTITY_ID: DOMAIN + ".test"}
|
||||
)
|
||||
|
||||
hass.async_create_task(
|
||||
|
@ -306,6 +306,30 @@ async def test_call_with_required_features(hass, mock_entities):
|
||||
assert test_service_mock.call_count == 1
|
||||
|
||||
|
||||
async def test_call_with_sync_func(hass, mock_entities):
|
||||
"""Test invoking sync service calls."""
|
||||
test_service_mock = Mock()
|
||||
await service.entity_service_call(
|
||||
hass,
|
||||
[Mock(entities=mock_entities)],
|
||||
test_service_mock,
|
||||
ha.ServiceCall("test_domain", "test_service", {"entity_id": "light.kitchen"}),
|
||||
)
|
||||
assert test_service_mock.call_count == 1
|
||||
|
||||
|
||||
async def test_call_with_sync_attr(hass, mock_entities):
|
||||
"""Test invoking sync service calls."""
|
||||
mock_entities["light.kitchen"].sync_method = Mock()
|
||||
await service.entity_service_call(
|
||||
hass,
|
||||
[Mock(entities=mock_entities)],
|
||||
"sync_method",
|
||||
ha.ServiceCall("test_domain", "test_service", {"entity_id": "light.kitchen"}),
|
||||
)
|
||||
assert mock_entities["light.kitchen"].sync_method.call_count == 1
|
||||
|
||||
|
||||
async def test_call_context_user_not_exist(hass):
|
||||
"""Check we don't allow deleted users to do things."""
|
||||
with pytest.raises(exceptions.UnknownUser) as err:
|
||||
@ -348,7 +372,7 @@ async def test_call_context_target_all(hass, mock_service_platform_call, mock_en
|
||||
)
|
||||
|
||||
assert len(mock_service_platform_call.mock_calls) == 1
|
||||
entities = mock_service_platform_call.mock_calls[0][1][2]
|
||||
entities = mock_service_platform_call.mock_calls[0][1][3]
|
||||
assert entities == [mock_entities["light.kitchen"]]
|
||||
|
||||
|
||||
@ -379,7 +403,7 @@ async def test_call_context_target_specific(
|
||||
)
|
||||
|
||||
assert len(mock_service_platform_call.mock_calls) == 1
|
||||
entities = mock_service_platform_call.mock_calls[0][1][2]
|
||||
entities = mock_service_platform_call.mock_calls[0][1][3]
|
||||
assert entities == [mock_entities["light.kitchen"]]
|
||||
|
||||
|
||||
@ -422,7 +446,7 @@ async def test_call_no_context_target_all(
|
||||
)
|
||||
|
||||
assert len(mock_service_platform_call.mock_calls) == 1
|
||||
entities = mock_service_platform_call.mock_calls[0][1][2]
|
||||
entities = mock_service_platform_call.mock_calls[0][1][3]
|
||||
assert entities == list(mock_entities.values())
|
||||
|
||||
|
||||
@ -442,7 +466,7 @@ async def test_call_no_context_target_specific(
|
||||
)
|
||||
|
||||
assert len(mock_service_platform_call.mock_calls) == 1
|
||||
entities = mock_service_platform_call.mock_calls[0][1][2]
|
||||
entities = mock_service_platform_call.mock_calls[0][1][3]
|
||||
assert entities == [mock_entities["light.kitchen"]]
|
||||
|
||||
|
||||
@ -458,7 +482,7 @@ async def test_call_with_match_all(
|
||||
)
|
||||
|
||||
assert len(mock_service_platform_call.mock_calls) == 1
|
||||
entities = mock_service_platform_call.mock_calls[0][1][2]
|
||||
entities = mock_service_platform_call.mock_calls[0][1][3]
|
||||
assert entities == [
|
||||
mock_entities["light.kitchen"],
|
||||
mock_entities["light.living_room"],
|
||||
@ -480,7 +504,7 @@ async def test_call_with_omit_entity_id(
|
||||
)
|
||||
|
||||
assert len(mock_service_platform_call.mock_calls) == 1
|
||||
entities = mock_service_platform_call.mock_calls[0][1][2]
|
||||
entities = mock_service_platform_call.mock_calls[0][1][3]
|
||||
assert entities == []
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user