1
mirror of https://github.com/home-assistant/core synced 2024-07-30 21:18:57 +02:00

Use entity.async_request_call in service helper (#31454)

* Use entity.async_request_call in service helper

* Clean up semaphore handling

* Address comments

* Simplify call entity service helper

* Fix stupid rflink test
This commit is contained in:
Paulus Schoutsen 2020-02-04 15:30:15 -08:00 committed by GitHub
parent 2c439af165
commit e970177eeb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 211 additions and 193 deletions

View File

@ -23,6 +23,8 @@ from . import (
_LOGGER = logging.getLogger(__name__)
PARALLEL_UPDATES = 0
TYPE_STANDARD = "standard"
TYPE_INVERTED = "inverted"

View File

@ -31,6 +31,8 @@ from . import (
_LOGGER = logging.getLogger(__name__)
PARALLEL_UPDATES = 0
TYPE_DIMMABLE = "dimmable"
TYPE_SWITCHABLE = "switchable"
TYPE_HYBRID = "hybrid"

View File

@ -22,6 +22,8 @@ from . import (
_LOGGER = logging.getLogger(__name__)
PARALLEL_UPDATES = 0
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
{
vol.Optional(

View File

@ -568,7 +568,6 @@ class Entity(ABC):
# call an requests
async def async_request_call(self, coro):
"""Process request batched."""
if self.parallel_updates:
await self.parallel_updates.acquire()

View File

@ -62,22 +62,42 @@ class EntityPlatform:
# Platform is None for the EntityComponent "catch-all" EntityPlatform
# which powers entity_component.add_entities
if platform is None:
self.parallel_updates = None
self.parallel_updates_semaphore: Optional[asyncio.Semaphore] = None
self.parallel_updates_created = True
self.parallel_updates: Optional[asyncio.Semaphore] = None
return
self.parallel_updates = getattr(platform, "PARALLEL_UPDATES", None)
# semaphore will be created on demand
self.parallel_updates_semaphore = None
self.parallel_updates_created = False
self.parallel_updates = None
def _get_parallel_updates_semaphore(self) -> asyncio.Semaphore:
"""Get or create a semaphore for parallel updates."""
if self.parallel_updates_semaphore is None:
self.parallel_updates_semaphore = asyncio.Semaphore(
self.parallel_updates if self.parallel_updates else 1,
loop=self.hass.loop,
)
return self.parallel_updates_semaphore
@callback
def _get_parallel_updates_semaphore(
self, entity_has_async_update: bool
) -> Optional[asyncio.Semaphore]:
"""Get or create a semaphore for parallel updates.
Semaphore will be created on demand because we base it off if update method is async or not.
If parallel updates is set to 0, we skip the semaphore.
If parallel updates is set to a number, we initialize the semaphore to that number.
Default for entities with `async_update` method is 1. Otherwise it's 0.
"""
if self.parallel_updates_created:
return self.parallel_updates
self.parallel_updates_created = True
parallel_updates = getattr(self.platform, "PARALLEL_UPDATES", None)
if parallel_updates is None and not entity_has_async_update:
parallel_updates = 1
if parallel_updates == 0:
parallel_updates = None
if parallel_updates is not None:
self.parallel_updates = asyncio.Semaphore(parallel_updates)
return self.parallel_updates
async def async_setup(self, platform_config, discovery_info=None):
"""Set up the platform from a config file."""
@ -282,21 +302,9 @@ class EntityPlatform:
entity.hass = self.hass
entity.platform = self
# Async entity
# PARALLEL_UPDATES == None: entity.parallel_updates = None
# PARALLEL_UPDATES == 0: entity.parallel_updates = None
# PARALLEL_UPDATES > 0: entity.parallel_updates = Semaphore(p)
# Sync entity
# PARALLEL_UPDATES == None: entity.parallel_updates = Semaphore(1)
# PARALLEL_UPDATES == 0: entity.parallel_updates = None
# PARALLEL_UPDATES > 0: entity.parallel_updates = Semaphore(p)
if hasattr(entity, "async_update") and not self.parallel_updates:
entity.parallel_updates = None
elif not hasattr(entity, "async_update") and self.parallel_updates == 0:
entity.parallel_updates = None
else:
entity.parallel_updates = self._get_parallel_updates_semaphore()
entity.parallel_updates = self._get_parallel_updates_semaphore(
hasattr(entity, "async_update")
)
# Update properties before we generate the entity_id
if update_before_add:

View File

@ -316,16 +316,15 @@ async def entity_service_call(hass, platforms, func, call, required_features=Non
# Check the permissions
# A list with for each platform in platforms a list of entities to call
# the service on.
platforms_entities = []
# A list with entities to call the service on.
entity_candidates = []
if entity_perms is None:
for platform in platforms:
if target_all_entities:
platforms_entities.append(list(platform.entities.values()))
entity_candidates.extend(platform.entities.values())
else:
platforms_entities.append(
entity_candidates.extend(
[
entity
for entity in platform.entities.values()
@ -337,7 +336,7 @@ async def entity_service_call(hass, platforms, func, call, required_features=Non
# If we target all entities, we will select all entities the user
# is allowed to control.
for platform in platforms:
platforms_entities.append(
entity_candidates.extend(
[
entity
for entity in platform.entities.values()
@ -362,39 +361,20 @@ async def entity_service_call(hass, platforms, func, call, required_features=Non
platform_entities.append(entity)
platforms_entities.append(platform_entities)
entity_candidates.extend(platform_entities)
if not target_all_entities:
for platform_entities in platforms_entities:
for entity in platform_entities:
entity_ids.remove(entity.entity_id)
for entity in entity_candidates:
entity_ids.remove(entity.entity_id)
if entity_ids:
_LOGGER.warning(
"Unable to find referenced entities %s", ", ".join(sorted(entity_ids))
)
tasks = [
_handle_service_platform_call(
hass, func, data, entities, call.context, required_features
)
for platform, entities in zip(platforms, platforms_entities)
]
entities = []
if tasks:
done, pending = await asyncio.wait(tasks)
assert not pending
for future in done:
future.result() # pop exception if have
async def _handle_service_platform_call(
hass, func, data, entities, context, required_features
):
"""Handle a function call."""
tasks = []
for entity in entities:
for entity in entity_candidates:
if not entity.available:
continue
@ -404,27 +384,33 @@ async def _handle_service_platform_call(
):
continue
entity.async_set_context(context)
entities.append(entity)
if isinstance(func, str):
result = hass.async_add_job(partial(getattr(entity, func), **data))
else:
result = hass.async_add_job(func, entity, data)
if not entities:
return
# Guard because callback functions do not return a task when passed to async_add_job.
if result is not None:
result = await result
if asyncio.iscoroutine(result):
_LOGGER.error(
"Service %s for %s incorrectly returns a coroutine object. Await result instead in service handler. Report bug to integration author.",
func,
entity.entity_id,
done, pending = await asyncio.wait(
[
entity.async_request_call(
_handle_entity_call(hass, entity, func, data, call.context)
)
await result
for entity in entities
]
)
assert not pending
for future in done:
future.result() # pop exception if have
if entity.should_poll:
tasks.append(entity.async_update_ha_state(True))
tasks = []
for entity in entities:
if not entity.should_poll:
continue
# Context expires if the turn on commands took a long time.
# Set context again so it's there when we update
entity.async_set_context(call.context)
tasks.append(entity.async_update_ha_state(True))
if tasks:
done, pending = await asyncio.wait(tasks)
@ -433,6 +419,28 @@ async def _handle_service_platform_call(
future.result() # pop exception if have
async def _handle_entity_call(hass, entity, func, data, context):
"""Handle calling service method."""
entity.async_set_context(context)
if isinstance(func, str):
result = hass.async_add_job(partial(getattr(entity, func), **data))
else:
result = hass.async_add_job(func, entity, data)
# Guard because callback functions do not return a task when passed to async_add_job.
if result is not None:
await result
if asyncio.iscoroutine(result):
_LOGGER.error(
"Service %s for %s incorrectly returns a coroutine object. Await result instead in service handler. Report bug to integration author.",
func,
entity.entity_id,
)
await result
@bind_hass
@ha.callback
def async_register_admin_service(
@ -474,6 +482,7 @@ def verify_domain_control(hass: HomeAssistantType, domain: str) -> Callable:
return await service_handler(call)
user = await hass.auth.async_get_user(call.context.user_id)
if user is None:
raise UnknownUser(
context=call.context,
@ -482,14 +491,12 @@ def verify_domain_control(hass: HomeAssistantType, domain: str) -> Callable:
)
reg = await hass.helpers.entity_registry.async_get_registry()
entities = [
entity.entity_id
for entity in reg.entities.values()
if entity.platform == domain
]
for entity_id in entities:
if user.permissions.check_entity(entity_id, POLICY_CONTROL):
for entity in reg.entities.values():
if entity.platform != domain:
continue
if user.permissions.check_entity(entity.entity_id, POLICY_CONTROL):
return await service_handler(call)
raise Unauthorized(

View File

@ -270,8 +270,6 @@ async def test_parallel_updates_async_platform_with_constant(hass):
handle = list(component._platforms.values())[-1]
assert handle.parallel_updates == 2
class AsyncEntity(MockEntity):
"""Mock entity that has async_update."""
@ -296,7 +294,6 @@ async def test_parallel_updates_sync_platform(hass):
await component.async_setup({DOMAIN: {"platform": "platform"}})
handle = list(component._platforms.values())[-1]
assert handle.parallel_updates is None
class SyncEntity(MockEntity):
"""Mock entity that has update."""
@ -323,7 +320,6 @@ async def test_parallel_updates_sync_platform_with_constant(hass):
await component.async_setup({DOMAIN: {"platform": "platform"}})
handle = list(component._platforms.values())[-1]
assert handle.parallel_updates == 2
class SyncEntity(MockEntity):
"""Mock entity that has update."""

View File

@ -39,31 +39,29 @@ from tests.common import (
@pytest.fixture
def mock_service_platform_call():
def mock_handle_entity_call():
"""Mock service platform call."""
with patch(
"homeassistant.helpers.service._handle_service_platform_call",
"homeassistant.helpers.service._handle_entity_call",
side_effect=lambda *args: mock_coro(),
) as mock_call:
yield mock_call
@pytest.fixture
def mock_entities():
def mock_entities(hass):
"""Return mock entities in an ordered dict."""
kitchen = Mock(
kitchen = MockEntity(
entity_id="light.kitchen",
available=True,
should_poll=False,
supported_features=1,
platform="test_domain",
)
living_room = Mock(
living_room = MockEntity(
entity_id="light.living_room",
available=True,
should_poll=False,
supported_features=0,
platform="test_domain",
)
entities = OrderedDict()
entities[kitchen.entity_id] = kitchen
@ -374,7 +372,7 @@ async def test_call_context_user_not_exist(hass):
assert err.value.context.user_id == "non-existing"
async def test_call_context_target_all(hass, mock_service_platform_call, mock_entities):
async def test_call_context_target_all(hass, mock_handle_entity_call, mock_entities):
"""Check we only target allowed entities if targeting all."""
with patch(
"homeassistant.auth.AuthManager.async_get_user",
@ -398,13 +396,12 @@ async def test_call_context_target_all(hass, mock_service_platform_call, mock_en
),
)
assert len(mock_service_platform_call.mock_calls) == 1
entities = mock_service_platform_call.mock_calls[0][1][3]
assert entities == [mock_entities["light.kitchen"]]
assert len(mock_handle_entity_call.mock_calls) == 1
assert mock_handle_entity_call.mock_calls[0][1][1].entity_id == "light.kitchen"
async def test_call_context_target_specific(
hass, mock_service_platform_call, mock_entities
hass, mock_handle_entity_call, mock_entities
):
"""Check targeting specific entities."""
with patch(
@ -429,13 +426,12 @@ async def test_call_context_target_specific(
),
)
assert len(mock_service_platform_call.mock_calls) == 1
entities = mock_service_platform_call.mock_calls[0][1][3]
assert entities == [mock_entities["light.kitchen"]]
assert len(mock_handle_entity_call.mock_calls) == 1
assert mock_handle_entity_call.mock_calls[0][1][1].entity_id == "light.kitchen"
async def test_call_context_target_specific_no_auth(
hass, mock_service_platform_call, mock_entities
hass, mock_handle_entity_call, mock_entities
):
"""Check targeting specific entities without auth."""
with pytest.raises(exceptions.Unauthorized) as err:
@ -459,9 +455,7 @@ async def test_call_context_target_specific_no_auth(
assert err.value.entity_id == "light.kitchen"
async def test_call_no_context_target_all(
hass, mock_service_platform_call, mock_entities
):
async def test_call_no_context_target_all(hass, mock_handle_entity_call, mock_entities):
"""Check we target all if no user context given."""
await service.entity_service_call(
hass,
@ -472,13 +466,14 @@ async def test_call_no_context_target_all(
),
)
assert len(mock_service_platform_call.mock_calls) == 1
entities = mock_service_platform_call.mock_calls[0][1][3]
assert entities == list(mock_entities.values())
assert len(mock_handle_entity_call.mock_calls) == 2
assert [call[1][1] for call in mock_handle_entity_call.mock_calls] == list(
mock_entities.values()
)
async def test_call_no_context_target_specific(
hass, mock_service_platform_call, mock_entities
hass, mock_handle_entity_call, mock_entities
):
"""Check we can target specified entities."""
await service.entity_service_call(
@ -492,13 +487,12 @@ async def test_call_no_context_target_specific(
),
)
assert len(mock_service_platform_call.mock_calls) == 1
entities = mock_service_platform_call.mock_calls[0][1][3]
assert entities == [mock_entities["light.kitchen"]]
assert len(mock_handle_entity_call.mock_calls) == 1
assert mock_handle_entity_call.mock_calls[0][1][1].entity_id == "light.kitchen"
async def test_call_with_match_all(
hass, mock_service_platform_call, mock_entities, caplog
hass, mock_handle_entity_call, mock_entities, caplog
):
"""Check we only target allowed entities if targeting all."""
await service.entity_service_call(
@ -508,20 +502,13 @@ async def test_call_with_match_all(
ha.ServiceCall("test_domain", "test_service", {"entity_id": "all"}),
)
assert len(mock_service_platform_call.mock_calls) == 1
entities = mock_service_platform_call.mock_calls[0][1][3]
assert entities == [
mock_entities["light.kitchen"],
mock_entities["light.living_room"],
]
assert (
"Not passing an entity ID to a service to target all entities is deprecated"
) not in caplog.text
assert len(mock_handle_entity_call.mock_calls) == 2
assert [call[1][1] for call in mock_handle_entity_call.mock_calls] == list(
mock_entities.values()
)
async def test_call_with_omit_entity_id(
hass, mock_service_platform_call, mock_entities
):
async def test_call_with_omit_entity_id(hass, mock_handle_entity_call, mock_entities):
"""Check service call if we do not pass an entity ID."""
await service.entity_service_call(
hass,
@ -530,9 +517,7 @@ async def test_call_with_omit_entity_id(
ha.ServiceCall("test_domain", "test_service"),
)
assert len(mock_service_platform_call.mock_calls) == 1
entities = mock_service_platform_call.mock_calls[0][1][3]
assert entities == []
assert len(mock_handle_entity_call.mock_calls) == 0
async def test_register_admin_service(hass, hass_read_only_user, hass_admin_user):
@ -644,96 +629,113 @@ async def test_domain_control_unknown(hass, mock_entities):
assert len(calls) == 0
async def test_domain_control_unauthorized(hass, hass_read_only_user, mock_entities):
async def test_domain_control_unauthorized(hass, hass_read_only_user):
"""Test domain verification in a service call with an unauthorized user."""
calls = []
async def mock_service_log(call):
"""Define a protected service."""
calls.append(call)
with patch(
"homeassistant.helpers.entity_registry.async_get_registry",
return_value=mock_coro(Mock(entities=mock_entities)),
):
protected_mock_service = hass.helpers.service.verify_domain_control(
"test_domain"
)(mock_service_log)
hass.services.async_register(
"test_domain", "test_service", protected_mock_service, schema=None
)
with pytest.raises(exceptions.Unauthorized):
await hass.services.async_call(
"test_domain",
"test_service",
{},
blocking=True,
context=ha.Context(user_id=hass_read_only_user.id),
mock_registry(
hass,
{
"light.kitchen": ent_reg.RegistryEntry(
entity_id="light.kitchen", unique_id="kitchen", platform="test_domain",
)
},
)
calls = []
async def mock_service_log(call):
"""Define a protected service."""
calls.append(call)
protected_mock_service = hass.helpers.service.verify_domain_control("test_domain")(
mock_service_log
)
hass.services.async_register(
"test_domain", "test_service", protected_mock_service, schema=None
)
with pytest.raises(exceptions.Unauthorized):
await hass.services.async_call(
"test_domain",
"test_service",
{},
blocking=True,
context=ha.Context(user_id=hass_read_only_user.id),
)
assert len(calls) == 0
async def test_domain_control_admin(hass, hass_admin_user, mock_entities):
async def test_domain_control_admin(hass, hass_admin_user):
"""Test domain verification in a service call with an admin user."""
mock_registry(
hass,
{
"light.kitchen": ent_reg.RegistryEntry(
entity_id="light.kitchen", unique_id="kitchen", platform="test_domain",
)
},
)
calls = []
async def mock_service_log(call):
"""Define a protected service."""
calls.append(call)
with patch(
"homeassistant.helpers.entity_registry.async_get_registry",
return_value=mock_coro(Mock(entities=mock_entities)),
):
protected_mock_service = hass.helpers.service.verify_domain_control(
"test_domain"
)(mock_service_log)
protected_mock_service = hass.helpers.service.verify_domain_control("test_domain")(
mock_service_log
)
hass.services.async_register(
"test_domain", "test_service", protected_mock_service, schema=None
)
hass.services.async_register(
"test_domain", "test_service", protected_mock_service, schema=None
)
await hass.services.async_call(
"test_domain",
"test_service",
{},
blocking=True,
context=ha.Context(user_id=hass_admin_user.id),
)
await hass.services.async_call(
"test_domain",
"test_service",
{},
blocking=True,
context=ha.Context(user_id=hass_admin_user.id),
)
assert len(calls) == 1
assert len(calls) == 1
async def test_domain_control_no_user(hass, mock_entities):
async def test_domain_control_no_user(hass):
"""Test domain verification in a service call with no user."""
mock_registry(
hass,
{
"light.kitchen": ent_reg.RegistryEntry(
entity_id="light.kitchen", unique_id="kitchen", platform="test_domain",
)
},
)
calls = []
async def mock_service_log(call):
"""Define a protected service."""
calls.append(call)
with patch(
"homeassistant.helpers.entity_registry.async_get_registry",
return_value=mock_coro(Mock(entities=mock_entities)),
):
protected_mock_service = hass.helpers.service.verify_domain_control(
"test_domain"
)(mock_service_log)
protected_mock_service = hass.helpers.service.verify_domain_control("test_domain")(
mock_service_log
)
hass.services.async_register(
"test_domain", "test_service", protected_mock_service, schema=None
)
hass.services.async_register(
"test_domain", "test_service", protected_mock_service, schema=None
)
await hass.services.async_call(
"test_domain",
"test_service",
{},
blocking=True,
context=ha.Context(user_id=None),
)
await hass.services.async_call(
"test_domain",
"test_service",
{},
blocking=True,
context=ha.Context(user_id=None),
)
assert len(calls) == 1
assert len(calls) == 1
async def test_extract_from_service_available_device(hass):