Handle cancellation in ServiceRegistry.async_call (#33644)

This commit is contained in:
Phil Bruckner 2020-04-04 17:36:33 -05:00 committed by GitHub
parent d7e9959442
commit bf1b408038
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 83 additions and 18 deletions

View File

@ -28,6 +28,7 @@ from typing import (
Optional,
Set,
TypeVar,
Union,
)
import uuid
@ -1224,29 +1225,57 @@ class ServiceRegistry:
context=context,
)
coro = self._execute_service(handler, service_call)
if not blocking:
self._hass.async_create_task(self._safe_execute(handler, service_call))
self._run_service_in_background(coro, service_call)
return None
task = self._hass.async_create_task(coro)
try:
async with timeout(limit):
await asyncio.shield(self._execute_service(handler, service_call))
return True
except asyncio.TimeoutError:
return False
await asyncio.wait({task}, timeout=limit)
except asyncio.CancelledError:
# Task calling us was cancelled, so cancel service call task, and wait for
# it to be cancelled, within reason, before leaving.
_LOGGER.debug("Service call was cancelled: %s", service_call)
task.cancel()
await asyncio.wait({task}, timeout=SERVICE_CALL_LIMIT)
raise
async def _safe_execute(self, handler: Service, service_call: ServiceCall) -> None:
"""Execute a service and catch exceptions."""
try:
await self._execute_service(handler, service_call)
except Unauthorized:
_LOGGER.warning(
"Unauthorized service called %s/%s",
service_call.domain,
service_call.service,
)
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Error executing service %s", service_call)
if task.cancelled():
# Service call task was cancelled some other way, such as during shutdown.
_LOGGER.debug("Service was cancelled: %s", service_call)
raise asyncio.CancelledError
if task.done():
# Propagate any exceptions that might have happened during service call.
task.result()
# Service call completed successfully!
return True
# Service call task did not complete before timeout expired.
# Let it keep running in background.
self._run_service_in_background(task, service_call)
_LOGGER.debug("Service did not complete before timeout: %s", service_call)
return False
def _run_service_in_background(
self, coro_or_task: Union[Coroutine, asyncio.Task], service_call: ServiceCall
) -> None:
"""Run service call in background, catching and logging any exceptions."""
async def catch_exceptions() -> None:
try:
await coro_or_task
except Unauthorized:
_LOGGER.warning(
"Unauthorized service called %s/%s",
service_call.domain,
service_call.service,
)
except asyncio.CancelledError:
_LOGGER.debug("Service was cancelled: %s", service_call)
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Error executing service: %s", service_call)
self._hass.async_create_task(catch_exceptions())
async def _execute_service(
self, handler: Service, service_call: ServiceCall

View File

@ -1214,6 +1214,42 @@ async def test_async_functions_with_callback(hass):
assert len(runs) == 3
@pytest.mark.parametrize("cancel_call", [True, False])
async def test_cancel_service_task(hass, cancel_call):
"""Test cancellation."""
service_called = asyncio.Event()
service_cancelled = False
async def service_handler(call):
nonlocal service_cancelled
service_called.set()
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
service_cancelled = True
raise
hass.services.async_register("test_domain", "test_service", service_handler)
call_task = hass.async_create_task(
hass.services.async_call("test_domain", "test_service", blocking=True)
)
tasks_1 = asyncio.all_tasks()
await asyncio.wait_for(service_called.wait(), timeout=1)
tasks_2 = asyncio.all_tasks() - tasks_1
assert len(tasks_2) == 1
service_task = tasks_2.pop()
if cancel_call:
call_task.cancel()
else:
service_task.cancel()
with pytest.raises(asyncio.CancelledError):
await call_task
assert service_cancelled
def test_valid_entity_id():
"""Test valid entity ID."""
for invalid in [