diff --git a/homeassistant/components/camera/__init__.py b/homeassistant/components/camera/__init__.py index 627da2d18721..2ed8b58232d6 100644 --- a/homeassistant/components/camera/__init__.py +++ b/homeassistant/components/camera/__init__.py @@ -386,7 +386,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: continue stream.keepalive = True stream.add_provider("hls") - stream.start() + await stream.start() hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, preload_stream) @@ -996,7 +996,7 @@ async def _async_stream_endpoint_url( stream.keepalive = camera_prefs.preload_stream stream.add_provider(fmt) - stream.start() + await stream.start() return stream.endpoint_url(fmt) diff --git a/homeassistant/components/nest/camera_sdm.py b/homeassistant/components/nest/camera_sdm.py index 6e14100e881e..61f8ead4ea38 100644 --- a/homeassistant/components/nest/camera_sdm.py +++ b/homeassistant/components/nest/camera_sdm.py @@ -175,7 +175,7 @@ class NestCamera(Camera): # Next attempt to catch a url will get a new one self._stream = None if self.stream: - self.stream.stop() + await self.stream.stop() self.stream = None return # Update the stream worker with the latest valid url diff --git a/homeassistant/components/stream/__init__.py b/homeassistant/components/stream/__init__.py index 895bdaf3201c..c33188fd71cb 100644 --- a/homeassistant/components/stream/__init__.py +++ b/homeassistant/components/stream/__init__.py @@ -16,6 +16,7 @@ to always keep workers active. """ from __future__ import annotations +import asyncio from collections.abc import Callable, Mapping import logging import re @@ -206,13 +207,16 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: # Setup Recorder async_setup_recorder(hass) - @callback - def shutdown(event: Event) -> None: + async def shutdown(event: Event) -> None: """Stop all stream workers.""" for stream in hass.data[DOMAIN][ATTR_STREAMS]: stream.keepalive = False - stream.stop() - _LOGGER.info("Stopped stream workers") + if awaitables := [ + asyncio.create_task(stream.stop()) + for stream in hass.data[DOMAIN][ATTR_STREAMS] + ]: + await asyncio.wait(awaitables) + _LOGGER.debug("Stopped stream workers") hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, shutdown) @@ -236,6 +240,7 @@ class Stream: self._stream_label = stream_label self.keepalive = False self.access_token: str | None = None + self._start_stop_lock = asyncio.Lock() self._thread: threading.Thread | None = None self._thread_quit = threading.Event() self._outputs: dict[str, StreamOutput] = {} @@ -271,12 +276,11 @@ class Stream: """Add provider output stream.""" if not (provider := self._outputs.get(fmt)): - @callback - def idle_callback() -> None: + async def idle_callback() -> None: if ( not self.keepalive or fmt == RECORDER_PROVIDER ) and fmt in self._outputs: - self.remove_provider(self._outputs[fmt]) + await self.remove_provider(self._outputs[fmt]) self.check_idle() provider = PROVIDERS[fmt]( @@ -286,14 +290,14 @@ class Stream: return provider - def remove_provider(self, provider: StreamOutput) -> None: + async def remove_provider(self, provider: StreamOutput) -> None: """Remove provider output stream.""" if provider.name in self._outputs: self._outputs[provider.name].cleanup() del self._outputs[provider.name] if not self._outputs: - self.stop() + await self.stop() def check_idle(self) -> None: """Reset access token if all providers are idle.""" @@ -316,9 +320,14 @@ class Stream: if self._update_callback: self._update_callback() - def start(self) -> None: - """Start a stream.""" - if self._thread is None or not self._thread.is_alive(): + async def start(self) -> None: + """Start a stream. + + Uses an asyncio.Lock to avoid conflicts with _stop(). + """ + async with self._start_stop_lock: + if self._thread and self._thread.is_alive(): + return if self._thread is not None: # The thread must have crashed/exited. Join to clean up the # previous thread. @@ -329,7 +338,7 @@ class Stream: target=self._run_worker, ) self._thread.start() - self._logger.info( + self._logger.debug( "Started stream: %s", redact_credentials(str(self.source)) ) @@ -394,33 +403,39 @@ class Stream: redact_credentials(str(self.source)), ) - @callback - def worker_finished() -> None: + async def worker_finished() -> None: # The worker is no checking availability of the stream and can no longer track # availability so mark it as available, otherwise the frontend may not be able to # interact with the stream. if not self.available: self._async_update_state(True) + # We can call remove_provider() sequentially as the wrapped _stop() function + # which blocks internally is only called when the last provider is removed. for provider in self.outputs().values(): - self.remove_provider(provider) + await self.remove_provider(provider) - self.hass.loop.call_soon_threadsafe(worker_finished) + self.hass.create_task(worker_finished()) - def stop(self) -> None: + async def stop(self) -> None: """Remove outputs and access token.""" self._outputs = {} self.access_token = None if not self.keepalive: - self._stop() + await self._stop() - def _stop(self) -> None: - """Stop worker thread.""" - if self._thread is not None: + async def _stop(self) -> None: + """Stop worker thread. + + Uses an asyncio.Lock to avoid conflicts with start(). + """ + async with self._start_stop_lock: + if self._thread is None: + return self._thread_quit.set() - self._thread.join() + await self.hass.async_add_executor_job(self._thread.join) self._thread = None - self._logger.info( + self._logger.debug( "Stopped stream: %s", redact_credentials(str(self.source)) ) @@ -448,7 +463,7 @@ class Stream: ) recorder.video_path = video_path - self.start() + await self.start() self._logger.debug("Started a stream recording of %s seconds", duration) # Take advantage of lookback @@ -473,7 +488,7 @@ class Stream: """ self.add_provider(HLS_PROVIDER) - self.start() + await self.start() return await self._keyframe_converter.async_get_image( width=width, height=height ) diff --git a/homeassistant/components/stream/core.py b/homeassistant/components/stream/core.py index 8c0b867752e3..da18a5a6a088 100644 --- a/homeassistant/components/stream/core.py +++ b/homeassistant/components/stream/core.py @@ -3,9 +3,9 @@ from __future__ import annotations import asyncio from collections import deque -from collections.abc import Iterable +from collections.abc import Callable, Coroutine, Iterable import datetime -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from aiohttp import web import async_timeout @@ -192,7 +192,10 @@ class IdleTimer: """ def __init__( - self, hass: HomeAssistant, timeout: int, idle_callback: CALLBACK_TYPE + self, + hass: HomeAssistant, + timeout: int, + idle_callback: Callable[[], Coroutine[Any, Any, None]], ) -> None: """Initialize IdleTimer.""" self._hass = hass @@ -219,11 +222,12 @@ class IdleTimer: if self._unsub is not None: self._unsub() + @callback def fire(self, _now: datetime.datetime) -> None: """Invoke the idle timeout callback, called when the alarm fires.""" self.idle = True self._unsub = None - self._callback() + self._hass.async_create_task(self._callback()) class StreamOutput: @@ -349,7 +353,7 @@ class StreamView(HomeAssistantView): raise web.HTTPNotFound() # Start worker if not already started - stream.start() + await stream.start() return await self.handle(request, stream, sequence, part_num) diff --git a/homeassistant/components/stream/hls.py b/homeassistant/components/stream/hls.py index 23584b59fb94..8e78093d07aa 100644 --- a/homeassistant/components/stream/hls.py +++ b/homeassistant/components/stream/hls.py @@ -117,7 +117,7 @@ class HlsMasterPlaylistView(StreamView): ) -> web.Response: """Return m3u8 playlist.""" track = stream.add_provider(HLS_PROVIDER) - stream.start() + await stream.start() # Make sure at least two segments are ready (last one may not be complete) if not track.sequences and not await track.recv(): return web.HTTPNotFound() @@ -232,7 +232,7 @@ class HlsPlaylistView(StreamView): track: HlsStreamOutput = cast( HlsStreamOutput, stream.add_provider(HLS_PROVIDER) ) - stream.start() + await stream.start() hls_msn: str | int | None = request.query.get("_HLS_msn") hls_part: str | int | None = request.query.get("_HLS_part") diff --git a/tests/components/camera/test_init.py b/tests/components/camera/test_init.py index 1b30facf1de8..ba13bbd6c52b 100644 --- a/tests/components/camera/test_init.py +++ b/tests/components/camera/test_init.py @@ -3,7 +3,7 @@ import asyncio import base64 from http import HTTPStatus import io -from unittest.mock import Mock, PropertyMock, mock_open, patch +from unittest.mock import AsyncMock, Mock, PropertyMock, mock_open, patch import pytest @@ -410,6 +410,7 @@ async def test_preload_stream(hass, mock_stream): "homeassistant.components.demo.camera.DemoCamera.stream_source", return_value="http://example.com", ): + mock_create_stream.return_value.start = AsyncMock() assert await async_setup_component( hass, "camera", {DOMAIN: {"platform": "demo"}} ) diff --git a/tests/components/nest/test_camera_sdm.py b/tests/components/nest/test_camera_sdm.py index 42b236fda7ca..5c4194f46f68 100644 --- a/tests/components/nest/test_camera_sdm.py +++ b/tests/components/nest/test_camera_sdm.py @@ -158,6 +158,7 @@ async def mock_create_stream(hass) -> Mock: ) mock_stream.return_value.async_get_image = AsyncMock() mock_stream.return_value.async_get_image.return_value = IMAGE_BYTES_FROM_STREAM + mock_stream.return_value.start = AsyncMock() yield mock_stream @@ -370,6 +371,7 @@ async def test_refresh_expired_stream_token( # Request a stream for the camera entity to exercise nest cam + camera interaction # and shutdown on url expiration with patch("homeassistant.components.camera.create_stream") as create_stream: + create_stream.return_value.start = AsyncMock() hls_url = await camera.async_request_stream(hass, "camera.my_camera", fmt="hls") assert hls_url.startswith("/api/hls/") # Includes access token assert create_stream.called @@ -536,7 +538,8 @@ async def test_refresh_expired_stream_failure( # Request an HLS stream with patch("homeassistant.components.camera.create_stream") as create_stream: - + create_stream.return_value.start = AsyncMock() + create_stream.return_value.stop = AsyncMock() hls_url = await camera.async_request_stream(hass, "camera.my_camera", fmt="hls") assert hls_url.startswith("/api/hls/") # Includes access token assert create_stream.called @@ -555,6 +558,7 @@ async def test_refresh_expired_stream_failure( # Requesting an HLS stream will create an entirely new stream with patch("homeassistant.components.camera.create_stream") as create_stream: + create_stream.return_value.start = AsyncMock() # The HLS stream endpoint was invalidated, with a new auth token hls_url2 = await camera.async_request_stream( hass, "camera.my_camera", fmt="hls" diff --git a/tests/components/stream/test_hls.py b/tests/components/stream/test_hls.py index 8e01c55de840..7343b96ef9a2 100644 --- a/tests/components/stream/test_hls.py +++ b/tests/components/stream/test_hls.py @@ -144,7 +144,7 @@ async def test_hls_stream( # Request stream stream.add_provider(HLS_PROVIDER) - stream.start() + await stream.start() hls_client = await hls_stream(stream) @@ -171,7 +171,7 @@ async def test_hls_stream( stream_worker_sync.resume() # Stop stream, if it hasn't quit already - stream.stop() + await stream.stop() # Ensure playlist not accessible after stream ends fail_response = await hls_client.get() @@ -205,7 +205,7 @@ async def test_stream_timeout( # Request stream stream.add_provider(HLS_PROVIDER) - stream.start() + await stream.start() url = stream.endpoint_url(HLS_PROVIDER) http_client = await hass_client() @@ -218,6 +218,7 @@ async def test_stream_timeout( # Wait a minute future = dt_util.utcnow() + timedelta(minutes=1) async_fire_time_changed(hass, future) + await hass.async_block_till_done() # Fetch again to reset timer playlist_response = await http_client.get(parsed_url.path) @@ -249,10 +250,10 @@ async def test_stream_timeout_after_stop( # Request stream stream.add_provider(HLS_PROVIDER) - stream.start() + await stream.start() stream_worker_sync.resume() - stream.stop() + await stream.stop() # Wait 5 minutes and fire callback. Stream should already have been # stopped so this is a no-op. @@ -297,14 +298,14 @@ async def test_stream_retries(hass, setup_component, should_retry): mock_time.time.side_effect = time_side_effect # Request stream. Enable retries which are disabled by default in tests. should_retry.return_value = True - stream.start() + await stream.start() stream._thread.join() stream._thread = None assert av_open.call_count == 2 await hass.async_block_till_done() # Stop stream, if it hasn't quit already - stream.stop() + await stream.stop() # Stream marked initially available, then marked as failed, then marked available # before the final failure that exits the stream. @@ -351,7 +352,7 @@ async def test_hls_playlist_view(hass, setup_component, hls_stream, stream_worke ) stream_worker_sync.resume() - stream.stop() + await stream.stop() async def test_hls_max_segments(hass, setup_component, hls_stream, stream_worker_sync): @@ -400,7 +401,7 @@ async def test_hls_max_segments(hass, setup_component, hls_stream, stream_worker assert segment_response.status == HTTPStatus.OK stream_worker_sync.resume() - stream.stop() + await stream.stop() async def test_hls_playlist_view_discontinuity( @@ -438,7 +439,7 @@ async def test_hls_playlist_view_discontinuity( ) stream_worker_sync.resume() - stream.stop() + await stream.stop() async def test_hls_max_segments_discontinuity( @@ -481,7 +482,7 @@ async def test_hls_max_segments_discontinuity( ) stream_worker_sync.resume() - stream.stop() + await stream.stop() async def test_remove_incomplete_segment_on_exit( @@ -490,7 +491,7 @@ async def test_remove_incomplete_segment_on_exit( """Test that the incomplete segment gets removed when the worker thread quits.""" stream = create_stream(hass, STREAM_SOURCE, {}) stream_worker_sync.pause() - stream.start() + await stream.start() hls = stream.add_provider(HLS_PROVIDER) segment = Segment(sequence=0, stream_id=0, duration=SEGMENT_DURATION) @@ -511,4 +512,4 @@ async def test_remove_incomplete_segment_on_exit( await hass.async_block_till_done() assert segments[-1].complete assert len(segments) == 2 - stream.stop() + await stream.stop() diff --git a/tests/components/stream/test_ll_hls.py b/tests/components/stream/test_ll_hls.py index 9a0d94136b93..4aaec93d646a 100644 --- a/tests/components/stream/test_ll_hls.py +++ b/tests/components/stream/test_ll_hls.py @@ -144,7 +144,7 @@ async def test_ll_hls_stream(hass, hls_stream, stream_worker_sync): # Request stream stream.add_provider(HLS_PROVIDER) - stream.start() + await stream.start() hls_client = await hls_stream(stream) @@ -243,7 +243,7 @@ async def test_ll_hls_stream(hass, hls_stream, stream_worker_sync): stream_worker_sync.resume() # Stop stream, if it hasn't quit already - stream.stop() + await stream.stop() # Ensure playlist not accessible after stream ends fail_response = await hls_client.get() @@ -316,7 +316,7 @@ async def test_ll_hls_playlist_view(hass, hls_stream, stream_worker_sync): ) stream_worker_sync.resume() - stream.stop() + await stream.stop() async def test_ll_hls_msn(hass, hls_stream, stream_worker_sync, hls_sync): diff --git a/tests/components/stream/test_recorder.py b/tests/components/stream/test_recorder.py index 50aa4df3f1c5..9433cbd449d4 100644 --- a/tests/components/stream/test_recorder.py +++ b/tests/components/stream/test_recorder.py @@ -46,7 +46,7 @@ async def test_record_stream(hass, hass_client, record_worker_sync, h264_video): # thread completes and is shutdown completely to avoid thread leaks. await record_worker_sync.join() - stream.stop() + await stream.stop() async def test_record_lookback( @@ -59,14 +59,14 @@ async def test_record_lookback( # Start an HLS feed to enable lookback stream.add_provider(HLS_PROVIDER) - stream.start() + await stream.start() with patch.object(hass.config, "is_allowed_path", return_value=True): await stream.async_record("/example/path", lookback=4) # This test does not need recorder cleanup since it is not fully exercised - stream.stop() + await stream.stop() async def test_recorder_timeout(hass, hass_client, stream_worker_sync, h264_video): @@ -97,7 +97,7 @@ async def test_recorder_timeout(hass, hass_client, stream_worker_sync, h264_vide assert mock_timeout.called stream_worker_sync.resume() - stream.stop() + await stream.stop() await hass.async_block_till_done() await hass.async_block_till_done() @@ -229,7 +229,7 @@ async def test_record_stream_audio( assert len(result.streams.audio) == expected_audio_streams result.close() - stream.stop() + await stream.stop() await hass.async_block_till_done() # Verify that the save worker was invoked, then block until its diff --git a/tests/components/stream/test_worker.py b/tests/components/stream/test_worker.py index 2a44dd644557..a70f2be81b8b 100644 --- a/tests/components/stream/test_worker.py +++ b/tests/components/stream/test_worker.py @@ -651,12 +651,12 @@ async def test_stream_stopped_while_decoding(hass): return py_av.open(stream_source, args, kwargs) with patch("av.open", new=blocking_open): - stream.start() + await stream.start() assert worker_open.wait(TIMEOUT) # Note: There is a race here where the worker could start as soon # as the wake event is sent, completing all decode work. worker_wake.set() - stream.stop() + await stream.stop() # Stream is still considered available when the worker was still active and asked to stop assert stream.available @@ -688,7 +688,7 @@ async def test_update_stream_source(hass): return py_av.open(stream_source, args, kwargs) with patch("av.open", new=blocking_open): - stream.start() + await stream.start() assert worker_open.wait(TIMEOUT) assert last_stream_source == STREAM_SOURCE assert stream.available @@ -704,7 +704,7 @@ async def test_update_stream_source(hass): assert stream.available # Cleanup - stream.stop() + await stream.stop() async def test_worker_log(hass, caplog): @@ -796,7 +796,7 @@ async def test_durations(hass, record_worker_sync): await record_worker_sync.join() - stream.stop() + await stream.stop() async def test_has_keyframe(hass, record_worker_sync, h264_video): @@ -836,7 +836,7 @@ async def test_has_keyframe(hass, record_worker_sync, h264_video): await record_worker_sync.join() - stream.stop() + await stream.stop() async def test_h265_video_is_hvc1(hass, record_worker_sync): @@ -871,7 +871,7 @@ async def test_h265_video_is_hvc1(hass, record_worker_sync): await record_worker_sync.join() - stream.stop() + await stream.stop() assert stream.get_diagnostics() == { "container_format": "mov,mp4,m4a,3gp,3g2,mj2", @@ -905,4 +905,4 @@ async def test_get_image(hass, record_worker_sync): assert await stream.async_get_image() == EMPTY_8_6_JPEG - stream.stop() + await stream.stop()