diff --git a/homeassistant/core.py b/homeassistant/core.py index 34d3dabc9a99..c94e56c6b570 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -77,6 +77,7 @@ from .exceptions import ( ServiceNotFound, Unauthorized, ) +from .helpers.aiohttp_compat import restore_original_aiohttp_cancel_behavior from .util import dt as dt_util, location, ulid as ulid_util from .util.async_ import ( fire_coroutine_threadsafe, @@ -105,6 +106,7 @@ STAGE_2_SHUTDOWN_TIMEOUT = 60 STAGE_3_SHUTDOWN_TIMEOUT = 30 block_async_io.enable() +restore_original_aiohttp_cancel_behavior() _T = TypeVar("_T") _R = TypeVar("_R") diff --git a/homeassistant/helpers/aiohttp_compat.py b/homeassistant/helpers/aiohttp_compat.py new file mode 100644 index 000000000000..1780cd053f59 --- /dev/null +++ b/homeassistant/helpers/aiohttp_compat.py @@ -0,0 +1,25 @@ +"""Helper to restore old aiohttp behavior.""" +from __future__ import annotations + +from aiohttp import web_protocol, web_server + + +class CancelOnDisconnectRequestHandler(web_protocol.RequestHandler): + """Request handler that cancels tasks on disconnect.""" + + def connection_lost(self, exc: BaseException | None) -> None: + """Handle connection lost.""" + task_handler = self._task_handler + super().connection_lost(exc) + if task_handler is not None: + task_handler.cancel() + + +def restore_original_aiohttp_cancel_behavior() -> None: + """Patch aiohttp to restore cancel behavior. + + Remove this once aiohttp 3.9 is released as we can use + https://github.com/aio-libs/aiohttp/pull/7128 + """ + web_protocol.RequestHandler = CancelOnDisconnectRequestHandler # type: ignore[misc] + web_server.RequestHandler = CancelOnDisconnectRequestHandler # type: ignore[misc] diff --git a/tests/helpers/test_aiohttp_compat.py b/tests/helpers/test_aiohttp_compat.py new file mode 100644 index 000000000000..749984dbc2ed --- /dev/null +++ b/tests/helpers/test_aiohttp_compat.py @@ -0,0 +1,55 @@ +"""Test the aiohttp compatibility shim.""" + +import asyncio +from contextlib import suppress + +from aiohttp import client, web, web_protocol, web_server +import pytest + +from homeassistant.helpers.aiohttp_compat import CancelOnDisconnectRequestHandler + + +@pytest.mark.allow_hosts(["127.0.0.1"]) +async def test_handler_cancellation(socket_enabled, unused_tcp_port_factory) -> None: + """Test that handler cancels the request on disconnect. + + From aiohttp tests/test_web_server.py + """ + assert web_protocol.RequestHandler is CancelOnDisconnectRequestHandler + assert web_server.RequestHandler is CancelOnDisconnectRequestHandler + + event = asyncio.Event() + port = unused_tcp_port_factory() + + async def on_request(_: web.Request) -> web.Response: + nonlocal event + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + event.set() + raise + else: + raise web.HTTPInternalServerError() + + app = web.Application() + app.router.add_route("GET", "/", on_request) + + runner = web.AppRunner(app) + await runner.setup() + + site = web.TCPSite(runner, host="127.0.0.1", port=port) + + await site.start() + + try: + async with client.ClientSession( + timeout=client.ClientTimeout(total=0.1) + ) as sess: + with pytest.raises(asyncio.TimeoutError): + await sess.get(f"http://127.0.0.1:{port}/") + + with suppress(asyncio.TimeoutError): + await asyncio.wait_for(event.wait(), timeout=1) + assert event.is_set(), "Request handler hasn't been cancelled" + finally: + await asyncio.gather(runner.shutdown(), site.stop())