Make Stream.stop() async (#73107)

* Make Stream.start() async
* Stop streams concurrently on shutdown
Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
uvjustin 2022-06-08 02:10:53 +10:00 committed by GitHub
parent c6b835dd91
commit 73f2bca377
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 92 additions and 67 deletions

View File

@ -386,7 +386,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
continue
stream.keepalive = True
stream.add_provider("hls")
stream.start()
await stream.start()
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, preload_stream)
@ -996,7 +996,7 @@ async def _async_stream_endpoint_url(
stream.keepalive = camera_prefs.preload_stream
stream.add_provider(fmt)
stream.start()
await stream.start()
return stream.endpoint_url(fmt)

View File

@ -175,7 +175,7 @@ class NestCamera(Camera):
# Next attempt to catch a url will get a new one
self._stream = None
if self.stream:
self.stream.stop()
await self.stream.stop()
self.stream = None
return
# Update the stream worker with the latest valid url

View File

@ -16,6 +16,7 @@ to always keep workers active.
"""
from __future__ import annotations
import asyncio
from collections.abc import Callable, Mapping
import logging
import re
@ -206,13 +207,16 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
# Setup Recorder
async_setup_recorder(hass)
@callback
def shutdown(event: Event) -> None:
async def shutdown(event: Event) -> None:
"""Stop all stream workers."""
for stream in hass.data[DOMAIN][ATTR_STREAMS]:
stream.keepalive = False
stream.stop()
_LOGGER.info("Stopped stream workers")
if awaitables := [
asyncio.create_task(stream.stop())
for stream in hass.data[DOMAIN][ATTR_STREAMS]
]:
await asyncio.wait(awaitables)
_LOGGER.debug("Stopped stream workers")
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, shutdown)
@ -236,6 +240,7 @@ class Stream:
self._stream_label = stream_label
self.keepalive = False
self.access_token: str | None = None
self._start_stop_lock = asyncio.Lock()
self._thread: threading.Thread | None = None
self._thread_quit = threading.Event()
self._outputs: dict[str, StreamOutput] = {}
@ -271,12 +276,11 @@ class Stream:
"""Add provider output stream."""
if not (provider := self._outputs.get(fmt)):
@callback
def idle_callback() -> None:
async def idle_callback() -> None:
if (
not self.keepalive or fmt == RECORDER_PROVIDER
) and fmt in self._outputs:
self.remove_provider(self._outputs[fmt])
await self.remove_provider(self._outputs[fmt])
self.check_idle()
provider = PROVIDERS[fmt](
@ -286,14 +290,14 @@ class Stream:
return provider
def remove_provider(self, provider: StreamOutput) -> None:
async def remove_provider(self, provider: StreamOutput) -> None:
"""Remove provider output stream."""
if provider.name in self._outputs:
self._outputs[provider.name].cleanup()
del self._outputs[provider.name]
if not self._outputs:
self.stop()
await self.stop()
def check_idle(self) -> None:
"""Reset access token if all providers are idle."""
@ -316,9 +320,14 @@ class Stream:
if self._update_callback:
self._update_callback()
def start(self) -> None:
"""Start a stream."""
if self._thread is None or not self._thread.is_alive():
async def start(self) -> None:
"""Start a stream.
Uses an asyncio.Lock to avoid conflicts with _stop().
"""
async with self._start_stop_lock:
if self._thread and self._thread.is_alive():
return
if self._thread is not None:
# The thread must have crashed/exited. Join to clean up the
# previous thread.
@ -329,7 +338,7 @@ class Stream:
target=self._run_worker,
)
self._thread.start()
self._logger.info(
self._logger.debug(
"Started stream: %s", redact_credentials(str(self.source))
)
@ -394,33 +403,39 @@ class Stream:
redact_credentials(str(self.source)),
)
@callback
def worker_finished() -> None:
async def worker_finished() -> None:
# The worker is no checking availability of the stream and can no longer track
# availability so mark it as available, otherwise the frontend may not be able to
# interact with the stream.
if not self.available:
self._async_update_state(True)
# We can call remove_provider() sequentially as the wrapped _stop() function
# which blocks internally is only called when the last provider is removed.
for provider in self.outputs().values():
self.remove_provider(provider)
await self.remove_provider(provider)
self.hass.loop.call_soon_threadsafe(worker_finished)
self.hass.create_task(worker_finished())
def stop(self) -> None:
async def stop(self) -> None:
"""Remove outputs and access token."""
self._outputs = {}
self.access_token = None
if not self.keepalive:
self._stop()
await self._stop()
def _stop(self) -> None:
"""Stop worker thread."""
if self._thread is not None:
async def _stop(self) -> None:
"""Stop worker thread.
Uses an asyncio.Lock to avoid conflicts with start().
"""
async with self._start_stop_lock:
if self._thread is None:
return
self._thread_quit.set()
self._thread.join()
await self.hass.async_add_executor_job(self._thread.join)
self._thread = None
self._logger.info(
self._logger.debug(
"Stopped stream: %s", redact_credentials(str(self.source))
)
@ -448,7 +463,7 @@ class Stream:
)
recorder.video_path = video_path
self.start()
await self.start()
self._logger.debug("Started a stream recording of %s seconds", duration)
# Take advantage of lookback
@ -473,7 +488,7 @@ class Stream:
"""
self.add_provider(HLS_PROVIDER)
self.start()
await self.start()
return await self._keyframe_converter.async_get_image(
width=width, height=height
)

View File

@ -3,9 +3,9 @@ from __future__ import annotations
import asyncio
from collections import deque
from collections.abc import Iterable
from collections.abc import Callable, Coroutine, Iterable
import datetime
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from aiohttp import web
import async_timeout
@ -192,7 +192,10 @@ class IdleTimer:
"""
def __init__(
self, hass: HomeAssistant, timeout: int, idle_callback: CALLBACK_TYPE
self,
hass: HomeAssistant,
timeout: int,
idle_callback: Callable[[], Coroutine[Any, Any, None]],
) -> None:
"""Initialize IdleTimer."""
self._hass = hass
@ -219,11 +222,12 @@ class IdleTimer:
if self._unsub is not None:
self._unsub()
@callback
def fire(self, _now: datetime.datetime) -> None:
"""Invoke the idle timeout callback, called when the alarm fires."""
self.idle = True
self._unsub = None
self._callback()
self._hass.async_create_task(self._callback())
class StreamOutput:
@ -349,7 +353,7 @@ class StreamView(HomeAssistantView):
raise web.HTTPNotFound()
# Start worker if not already started
stream.start()
await stream.start()
return await self.handle(request, stream, sequence, part_num)

View File

@ -117,7 +117,7 @@ class HlsMasterPlaylistView(StreamView):
) -> web.Response:
"""Return m3u8 playlist."""
track = stream.add_provider(HLS_PROVIDER)
stream.start()
await stream.start()
# Make sure at least two segments are ready (last one may not be complete)
if not track.sequences and not await track.recv():
return web.HTTPNotFound()
@ -232,7 +232,7 @@ class HlsPlaylistView(StreamView):
track: HlsStreamOutput = cast(
HlsStreamOutput, stream.add_provider(HLS_PROVIDER)
)
stream.start()
await stream.start()
hls_msn: str | int | None = request.query.get("_HLS_msn")
hls_part: str | int | None = request.query.get("_HLS_part")

View File

@ -3,7 +3,7 @@ import asyncio
import base64
from http import HTTPStatus
import io
from unittest.mock import Mock, PropertyMock, mock_open, patch
from unittest.mock import AsyncMock, Mock, PropertyMock, mock_open, patch
import pytest
@ -410,6 +410,7 @@ async def test_preload_stream(hass, mock_stream):
"homeassistant.components.demo.camera.DemoCamera.stream_source",
return_value="http://example.com",
):
mock_create_stream.return_value.start = AsyncMock()
assert await async_setup_component(
hass, "camera", {DOMAIN: {"platform": "demo"}}
)

View File

@ -158,6 +158,7 @@ async def mock_create_stream(hass) -> Mock:
)
mock_stream.return_value.async_get_image = AsyncMock()
mock_stream.return_value.async_get_image.return_value = IMAGE_BYTES_FROM_STREAM
mock_stream.return_value.start = AsyncMock()
yield mock_stream
@ -370,6 +371,7 @@ async def test_refresh_expired_stream_token(
# Request a stream for the camera entity to exercise nest cam + camera interaction
# and shutdown on url expiration
with patch("homeassistant.components.camera.create_stream") as create_stream:
create_stream.return_value.start = AsyncMock()
hls_url = await camera.async_request_stream(hass, "camera.my_camera", fmt="hls")
assert hls_url.startswith("/api/hls/") # Includes access token
assert create_stream.called
@ -536,7 +538,8 @@ async def test_refresh_expired_stream_failure(
# Request an HLS stream
with patch("homeassistant.components.camera.create_stream") as create_stream:
create_stream.return_value.start = AsyncMock()
create_stream.return_value.stop = AsyncMock()
hls_url = await camera.async_request_stream(hass, "camera.my_camera", fmt="hls")
assert hls_url.startswith("/api/hls/") # Includes access token
assert create_stream.called
@ -555,6 +558,7 @@ async def test_refresh_expired_stream_failure(
# Requesting an HLS stream will create an entirely new stream
with patch("homeassistant.components.camera.create_stream") as create_stream:
create_stream.return_value.start = AsyncMock()
# The HLS stream endpoint was invalidated, with a new auth token
hls_url2 = await camera.async_request_stream(
hass, "camera.my_camera", fmt="hls"

View File

@ -144,7 +144,7 @@ async def test_hls_stream(
# Request stream
stream.add_provider(HLS_PROVIDER)
stream.start()
await stream.start()
hls_client = await hls_stream(stream)
@ -171,7 +171,7 @@ async def test_hls_stream(
stream_worker_sync.resume()
# Stop stream, if it hasn't quit already
stream.stop()
await stream.stop()
# Ensure playlist not accessible after stream ends
fail_response = await hls_client.get()
@ -205,7 +205,7 @@ async def test_stream_timeout(
# Request stream
stream.add_provider(HLS_PROVIDER)
stream.start()
await stream.start()
url = stream.endpoint_url(HLS_PROVIDER)
http_client = await hass_client()
@ -218,6 +218,7 @@ async def test_stream_timeout(
# Wait a minute
future = dt_util.utcnow() + timedelta(minutes=1)
async_fire_time_changed(hass, future)
await hass.async_block_till_done()
# Fetch again to reset timer
playlist_response = await http_client.get(parsed_url.path)
@ -249,10 +250,10 @@ async def test_stream_timeout_after_stop(
# Request stream
stream.add_provider(HLS_PROVIDER)
stream.start()
await stream.start()
stream_worker_sync.resume()
stream.stop()
await stream.stop()
# Wait 5 minutes and fire callback. Stream should already have been
# stopped so this is a no-op.
@ -297,14 +298,14 @@ async def test_stream_retries(hass, setup_component, should_retry):
mock_time.time.side_effect = time_side_effect
# Request stream. Enable retries which are disabled by default in tests.
should_retry.return_value = True
stream.start()
await stream.start()
stream._thread.join()
stream._thread = None
assert av_open.call_count == 2
await hass.async_block_till_done()
# Stop stream, if it hasn't quit already
stream.stop()
await stream.stop()
# Stream marked initially available, then marked as failed, then marked available
# before the final failure that exits the stream.
@ -351,7 +352,7 @@ async def test_hls_playlist_view(hass, setup_component, hls_stream, stream_worke
)
stream_worker_sync.resume()
stream.stop()
await stream.stop()
async def test_hls_max_segments(hass, setup_component, hls_stream, stream_worker_sync):
@ -400,7 +401,7 @@ async def test_hls_max_segments(hass, setup_component, hls_stream, stream_worker
assert segment_response.status == HTTPStatus.OK
stream_worker_sync.resume()
stream.stop()
await stream.stop()
async def test_hls_playlist_view_discontinuity(
@ -438,7 +439,7 @@ async def test_hls_playlist_view_discontinuity(
)
stream_worker_sync.resume()
stream.stop()
await stream.stop()
async def test_hls_max_segments_discontinuity(
@ -481,7 +482,7 @@ async def test_hls_max_segments_discontinuity(
)
stream_worker_sync.resume()
stream.stop()
await stream.stop()
async def test_remove_incomplete_segment_on_exit(
@ -490,7 +491,7 @@ async def test_remove_incomplete_segment_on_exit(
"""Test that the incomplete segment gets removed when the worker thread quits."""
stream = create_stream(hass, STREAM_SOURCE, {})
stream_worker_sync.pause()
stream.start()
await stream.start()
hls = stream.add_provider(HLS_PROVIDER)
segment = Segment(sequence=0, stream_id=0, duration=SEGMENT_DURATION)
@ -511,4 +512,4 @@ async def test_remove_incomplete_segment_on_exit(
await hass.async_block_till_done()
assert segments[-1].complete
assert len(segments) == 2
stream.stop()
await stream.stop()

View File

@ -144,7 +144,7 @@ async def test_ll_hls_stream(hass, hls_stream, stream_worker_sync):
# Request stream
stream.add_provider(HLS_PROVIDER)
stream.start()
await stream.start()
hls_client = await hls_stream(stream)
@ -243,7 +243,7 @@ async def test_ll_hls_stream(hass, hls_stream, stream_worker_sync):
stream_worker_sync.resume()
# Stop stream, if it hasn't quit already
stream.stop()
await stream.stop()
# Ensure playlist not accessible after stream ends
fail_response = await hls_client.get()
@ -316,7 +316,7 @@ async def test_ll_hls_playlist_view(hass, hls_stream, stream_worker_sync):
)
stream_worker_sync.resume()
stream.stop()
await stream.stop()
async def test_ll_hls_msn(hass, hls_stream, stream_worker_sync, hls_sync):

View File

@ -46,7 +46,7 @@ async def test_record_stream(hass, hass_client, record_worker_sync, h264_video):
# thread completes and is shutdown completely to avoid thread leaks.
await record_worker_sync.join()
stream.stop()
await stream.stop()
async def test_record_lookback(
@ -59,14 +59,14 @@ async def test_record_lookback(
# Start an HLS feed to enable lookback
stream.add_provider(HLS_PROVIDER)
stream.start()
await stream.start()
with patch.object(hass.config, "is_allowed_path", return_value=True):
await stream.async_record("/example/path", lookback=4)
# This test does not need recorder cleanup since it is not fully exercised
stream.stop()
await stream.stop()
async def test_recorder_timeout(hass, hass_client, stream_worker_sync, h264_video):
@ -97,7 +97,7 @@ async def test_recorder_timeout(hass, hass_client, stream_worker_sync, h264_vide
assert mock_timeout.called
stream_worker_sync.resume()
stream.stop()
await stream.stop()
await hass.async_block_till_done()
await hass.async_block_till_done()
@ -229,7 +229,7 @@ async def test_record_stream_audio(
assert len(result.streams.audio) == expected_audio_streams
result.close()
stream.stop()
await stream.stop()
await hass.async_block_till_done()
# Verify that the save worker was invoked, then block until its

View File

@ -651,12 +651,12 @@ async def test_stream_stopped_while_decoding(hass):
return py_av.open(stream_source, args, kwargs)
with patch("av.open", new=blocking_open):
stream.start()
await stream.start()
assert worker_open.wait(TIMEOUT)
# Note: There is a race here where the worker could start as soon
# as the wake event is sent, completing all decode work.
worker_wake.set()
stream.stop()
await stream.stop()
# Stream is still considered available when the worker was still active and asked to stop
assert stream.available
@ -688,7 +688,7 @@ async def test_update_stream_source(hass):
return py_av.open(stream_source, args, kwargs)
with patch("av.open", new=blocking_open):
stream.start()
await stream.start()
assert worker_open.wait(TIMEOUT)
assert last_stream_source == STREAM_SOURCE
assert stream.available
@ -704,7 +704,7 @@ async def test_update_stream_source(hass):
assert stream.available
# Cleanup
stream.stop()
await stream.stop()
async def test_worker_log(hass, caplog):
@ -796,7 +796,7 @@ async def test_durations(hass, record_worker_sync):
await record_worker_sync.join()
stream.stop()
await stream.stop()
async def test_has_keyframe(hass, record_worker_sync, h264_video):
@ -836,7 +836,7 @@ async def test_has_keyframe(hass, record_worker_sync, h264_video):
await record_worker_sync.join()
stream.stop()
await stream.stop()
async def test_h265_video_is_hvc1(hass, record_worker_sync):
@ -871,7 +871,7 @@ async def test_h265_video_is_hvc1(hass, record_worker_sync):
await record_worker_sync.join()
stream.stop()
await stream.stop()
assert stream.get_diagnostics() == {
"container_format": "mov,mp4,m4a,3gp,3g2,mj2",
@ -905,4 +905,4 @@ async def test_get_image(hass, record_worker_sync):
assert await stream.async_get_image() == EMPTY_8_6_JPEG
stream.stop()
await stream.stop()