1
mirror of https://github.com/home-assistant/core synced 2024-07-30 21:18:57 +02:00

Fix service annotations (#31402)

* Fix service annotations

* Filter area_id from service data

* Fix services not accepting entities

* Typo
This commit is contained in:
Paulus Schoutsen 2020-02-02 15:36:39 -08:00 committed by GitHub
parent 81dbdc6b9c
commit 7687ac8b91
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 78 additions and 29 deletions

View File

@ -143,11 +143,15 @@ async def async_setup(hass: HomeAssistantType, config: ConfigType) -> bool:
)
component.async_register_entity_service(
SERVICE_SELECT_NEXT, {}, lambda entity, call: entity.async_offset_index(1)
SERVICE_SELECT_NEXT,
{},
callback(lambda entity, call: entity.async_offset_index(1)),
)
component.async_register_entity_service(
SERVICE_SELECT_PREVIOUS, {}, lambda entity, call: entity.async_offset_index(-1)
SERVICE_SELECT_PREVIOUS,
{},
callback(lambda entity, call: entity.async_offset_index(-1)),
)
component.async_register_entity_service(
@ -248,7 +252,8 @@ class InputSelect(RestoreEntity):
"""Return unique id for the entity."""
return self._config[CONF_ID]
async def async_select_option(self, option):
@callback
def async_select_option(self, option):
"""Select new option."""
if option not in self._options:
_LOGGER.warning(
@ -260,14 +265,16 @@ class InputSelect(RestoreEntity):
self._current_option = option
self.async_write_ha_state()
async def async_offset_index(self, offset):
@callback
def async_offset_index(self, offset):
"""Offset current index."""
current_index = self._options.index(self._current_option)
new_index = (current_index + offset) % len(self._options)
self._current_option = self._options[new_index]
self.async_write_ha_state()
async def async_set_options(self, options):
@callback
def async_set_options(self, options):
"""Set options."""
self._current_option = options[0]
self._config[CONF_OPTIONS] = options

View File

@ -173,6 +173,23 @@ SCHEMA_WEBSOCKET_GET_THUMBNAIL = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.exten
)
def _rename_keys(**keys):
"""Create validator that renames keys.
Necessary because the service schema names do not match the command parameters.
Async friendly.
"""
def rename(value):
for to_key, from_key in keys.items():
if from_key in value:
value[to_key] = value.pop(from_key)
return value
return rename
async def async_setup(hass, config):
"""Track states and offer events for media_players."""
component = hass.data[DOMAIN] = EntityComponent(
@ -238,30 +255,39 @@ async def async_setup(hass, config):
)
component.async_register_entity_service(
SERVICE_VOLUME_SET,
{vol.Required(ATTR_MEDIA_VOLUME_LEVEL): cv.small_float},
lambda entity, call: entity.async_set_volume_level(
volume=call.data[ATTR_MEDIA_VOLUME_LEVEL]
vol.All(
cv.make_entity_service_schema(
{vol.Required(ATTR_MEDIA_VOLUME_LEVEL): cv.small_float}
),
_rename_keys(volume=ATTR_MEDIA_VOLUME_LEVEL),
),
"async_set_volume_level",
[SUPPORT_VOLUME_SET],
)
component.async_register_entity_service(
SERVICE_VOLUME_MUTE,
{vol.Required(ATTR_MEDIA_VOLUME_MUTED): cv.boolean},
lambda entity, call: entity.async_mute_volume(
mute=call.data[ATTR_MEDIA_VOLUME_MUTED]
vol.All(
cv.make_entity_service_schema(
{vol.Required(ATTR_MEDIA_VOLUME_MUTED): cv.boolean}
),
_rename_keys(mute=ATTR_MEDIA_VOLUME_MUTED),
),
"async_mute_volume",
[SUPPORT_VOLUME_MUTE],
)
component.async_register_entity_service(
SERVICE_MEDIA_SEEK,
{
vol.Required(ATTR_MEDIA_SEEK_POSITION): vol.All(
vol.Coerce(float), vol.Range(min=0)
)
},
lambda entity, call: entity.async_media_seek(
position=call.data[ATTR_MEDIA_SEEK_POSITION]
vol.All(
cv.make_entity_service_schema(
{
vol.Required(ATTR_MEDIA_SEEK_POSITION): vol.All(
vol.Coerce(float), vol.Range(min=0)
)
}
),
_rename_keys(position=ATTR_MEDIA_SEEK_POSITION),
),
"async_media_seek",
[SUPPORT_SEEK],
)
component.async_register_entity_service(
@ -278,12 +304,15 @@ async def async_setup(hass, config):
)
component.async_register_entity_service(
SERVICE_PLAY_MEDIA,
MEDIA_PLAYER_PLAY_MEDIA_SCHEMA,
lambda entity, call: entity.async_play_media(
media_type=call.data[ATTR_MEDIA_CONTENT_TYPE],
media_id=call.data[ATTR_MEDIA_CONTENT_ID],
enqueue=call.data.get(ATTR_MEDIA_ENQUEUE),
vol.All(
cv.make_entity_service_schema(MEDIA_PLAYER_PLAY_MEDIA_SCHEMA),
_rename_keys(
media_type=ATTR_MEDIA_CONTENT_TYPE,
media_id=ATTR_MEDIA_CONTENT_ID,
enqueue=ATTR_MEDIA_ENQUEUE,
),
),
"async_play_media",
[SUPPORT_PLAY_MEDIA],
)
component.async_register_entity_service(

View File

@ -724,6 +724,8 @@ PLATFORM_SCHEMA = vol.Schema(
PLATFORM_SCHEMA_BASE = PLATFORM_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA)
ENTITY_SERVICE_FIELDS = (ATTR_ENTITY_ID, ATTR_AREA_ID)
def make_entity_service_schema(
schema: dict, *, extra: int = vol.PREVENT_EXTRA
@ -738,7 +740,7 @@ def make_entity_service_schema(
},
extra=extra,
),
has_at_least_one_key(ATTR_ENTITY_ID, ATTR_AREA_ID),
has_at_least_one_key(*ENTITY_SERVICE_FIELDS),
)

View File

@ -283,7 +283,11 @@ async def entity_service_call(hass, platforms, func, call, required_features=Non
# If the service function is a string, we'll pass it the service call data
if isinstance(func, str):
data = {key: val for key, val in call.data.items() if key != ATTR_ENTITY_ID}
data = {
key: val
for key, val in call.data.items()
if key not in cv.ENTITY_SERVICE_FIELDS
}
# If the service function is not a string, we pass the service call
else:
data = call
@ -323,6 +327,7 @@ async def entity_service_call(hass, platforms, func, call, required_features=Non
for platform in platforms:
platform_entities = []
for entity in platform.entities.values():
if entity.entity_id not in entity_ids:
continue
@ -380,7 +385,7 @@ async def _handle_service_platform_call(
if asyncio.iscoroutine(result):
_LOGGER.error(
"Service %s for %s incorrectly returns a coroutine object. Await result instead in service handler. Report bug to component author.",
"Service %s for %s incorrectly returns a coroutine object. Await result instead in service handler. Report bug to integration author.",
func,
entity.entity_id,
)

View File

@ -320,14 +320,20 @@ async def test_call_with_sync_func(hass, mock_entities):
async def test_call_with_sync_attr(hass, mock_entities):
"""Test invoking sync service calls."""
mock_entities["light.kitchen"].sync_method = Mock()
mock_method = mock_entities["light.kitchen"].sync_method = Mock()
await service.entity_service_call(
hass,
[Mock(entities=mock_entities)],
"sync_method",
ha.ServiceCall("test_domain", "test_service", {"entity_id": "light.kitchen"}),
ha.ServiceCall(
"test_domain",
"test_service",
{"entity_id": "light.kitchen", "area_id": "abcd"},
),
)
assert mock_entities["light.kitchen"].sync_method.call_count == 1
assert mock_method.call_count == 1
# We pass empty kwargs because both entity_id and area_id are filtered out
assert mock_method.mock_calls[0][2] == {}
async def test_call_context_user_not_exist(hass):