Determine how to run listeners at setup time instead of execution time (#41304)

This commit is contained in:
J. Nick Koston 2020-10-07 09:51:50 -05:00 committed by GitHub
parent 8d94dff75c
commit 9e1461da62
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 216 additions and 119 deletions

View File

@ -149,6 +149,52 @@ def is_callback(func: Callable[..., Any]) -> bool:
return getattr(func, "_hass_callback", False) is True
@enum.unique
class HassJobType(enum.Enum):
"""Represent a job type."""
Coroutine = 1
Coroutinefunction = 2
Callback = 3
Executor = 4
class HassJob:
"""Represent a job to be run later.
We check the callable type in advance
so we can avoid checking it every time
we run the job.
"""
__slots__ = ("job_type", "target")
def __init__(self, target: Callable):
"""Create a job object."""
self.target = target
self.job_type = _get_callable_job_type(target)
def __repr__(self) -> str:
"""Return the job."""
return f"<Job {self.job_type} {self.target}>"
def _get_callable_job_type(target: Callable) -> HassJobType:
"""Determine the job type from the callable."""
# Check for partials to properly determine if coroutine function
check_target = target
while isinstance(check_target, functools.partial):
check_target = check_target.func
if asyncio.iscoroutine(check_target):
return HassJobType.Coroutine
if asyncio.iscoroutinefunction(check_target):
return HassJobType.Coroutinefunction
if is_callback(check_target):
return HassJobType.Callback
return HassJobType.Executor
class CoreState(enum.Enum):
"""Represent the current state of Home Assistant."""
@ -306,24 +352,32 @@ class HomeAssistant:
if target is None:
raise ValueError("Don't call async_add_job with None")
task = None
return self.async_add_hass_job(HassJob(target), *args)
# Check for partials to properly determine if coroutine function
check_target = target
while isinstance(check_target, functools.partial):
check_target = check_target.func
@callback
def async_add_hass_job(
self, hassjob: HassJob, *args: Any
) -> Optional[asyncio.Future]:
"""Add a HassJob from within the event loop.
if asyncio.iscoroutine(check_target):
task = self.loop.create_task(target) # type: ignore
elif asyncio.iscoroutinefunction(check_target):
task = self.loop.create_task(target(*args))
elif is_callback(check_target):
self.loop.call_soon(target, *args)
This method must be run in the event loop.
hassjob: HassJob to call.
args: parameters for method to call.
"""
if hassjob.job_type == HassJobType.Coroutine:
task = self.loop.create_task(hassjob.target) # type: ignore
elif hassjob.job_type == HassJobType.Coroutinefunction:
task = self.loop.create_task(hassjob.target(*args))
elif hassjob.job_type == HassJobType.Callback:
self.loop.call_soon(hassjob.target, *args)
return None
else:
task = self.loop.run_in_executor(None, target, *args) # type: ignore
task = self.loop.run_in_executor( # type: ignore
None, hassjob.target, *args
)
# If a task is scheduled
if self._track_task and task is not None:
if self._track_task:
self._pending_tasks.append(task)
return task
@ -366,6 +420,20 @@ class HomeAssistant:
"""Stop track tasks so you can't wait for all tasks to be done."""
self._track_task = False
@callback
def async_run_hass_job(self, hassjob: HassJob, *args: Any) -> None:
"""Run a HassJob from within the event loop.
This method must be run in the event loop.
hassjob: HassJob
args: parameters for method to call.
"""
if hassjob.job_type == HassJobType.Callback:
hassjob.target(*args)
else:
self.async_add_hass_job(hassjob, *args)
@callback
def async_run_job(
self, target: Callable[..., Union[None, Awaitable]], *args: Any
@ -377,14 +445,7 @@ class HomeAssistant:
target: target to call.
args: parameters for method to call.
"""
if (
not asyncio.iscoroutine(target)
and not asyncio.iscoroutinefunction(target)
and is_callback(target)
):
target(*args)
else:
self.async_add_job(target, *args)
self.async_run_hass_job(HassJob(target), *args)
def block_till_done(self) -> None:
"""Block until all pending work is done."""
@ -592,7 +653,7 @@ class EventBus:
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize a new event bus."""
self._listeners: Dict[str, List[Callable]] = {}
self._listeners: Dict[str, List[HassJob]] = {}
self._hass = hass
@callback
@ -648,8 +709,8 @@ class EventBus:
if not listeners:
return
for func in listeners:
self._hass.async_add_job(func, event)
for job in listeners:
self._hass.async_add_hass_job(job, event)
def listen(self, event_type: str, listener: Callable) -> CALLBACK_TYPE:
"""Listen for all events or events of a specific type.
@ -676,14 +737,15 @@ class EventBus:
This method must be run in the event loop.
"""
if event_type in self._listeners:
self._listeners[event_type].append(listener)
else:
self._listeners[event_type] = [listener]
return self._async_listen_job(event_type, HassJob(listener))
@callback
def _async_listen_job(self, event_type: str, hassjob: HassJob) -> CALLBACK_TYPE:
self._listeners.setdefault(event_type, []).append(hassjob)
def remove_listener() -> None:
"""Remove the listener."""
self._async_remove_listener(event_type, listener)
self._async_remove_listener(event_type, hassjob)
return remove_listener
@ -716,31 +778,36 @@ class EventBus:
This method must be run in the event loop.
"""
job: Optional[HassJob] = None
@callback
def onetime_listener(event: Event) -> None:
def _onetime_listener(event: Event) -> None:
"""Remove listener from event bus and then fire listener."""
if hasattr(onetime_listener, "run"):
nonlocal job
if hasattr(_onetime_listener, "run"):
return
# Set variable so that we will never run twice.
# Because the event bus loop might have async_fire queued multiple
# times, its possible this listener may already be lined up
# multiple times as well.
# This will make sure the second time it does nothing.
setattr(onetime_listener, "run", True)
self._async_remove_listener(event_type, onetime_listener)
setattr(_onetime_listener, "run", True)
assert job is not None
self._async_remove_listener(event_type, job)
self._hass.async_run_job(listener, event)
return self.async_listen(event_type, onetime_listener)
job = HassJob(_onetime_listener)
return self._async_listen_job(event_type, job)
@callback
def _async_remove_listener(self, event_type: str, listener: Callable) -> None:
def _async_remove_listener(self, event_type: str, hassjob: HassJob) -> None:
"""Remove a listener of a specific event_type.
This method must be run in the event loop.
"""
try:
self._listeners[event_type].remove(listener)
self._listeners[event_type].remove(hassjob)
# delete event_type list if empty
if not self._listeners[event_type]:
@ -748,7 +815,7 @@ class EventBus:
except (KeyError, ValueError):
# KeyError is key event_type listener did not exist
# ValueError if listener did not exist within event_type
_LOGGER.warning("Unable to remove unknown listener %s", listener)
_LOGGER.warning("Unable to remove unknown job listener %s", hassjob)
class State:
@ -1094,7 +1161,7 @@ class StateMachine:
class Service:
"""Representation of a callable service."""
__slots__ = ["func", "schema", "is_callback", "is_coroutinefunction"]
__slots__ = ["job", "schema"]
def __init__(
self,
@ -1103,13 +1170,8 @@ class Service:
context: Optional[Context] = None,
) -> None:
"""Initialize a service."""
self.func = func
self.job = HassJob(func)
self.schema = schema
# Properly detect wrapped functions
while isinstance(func, functools.partial):
func = func.func
self.is_callback = is_callback(func)
self.is_coroutinefunction = asyncio.iscoroutinefunction(func)
class ServiceCall:
@ -1377,12 +1439,12 @@ class ServiceRegistry:
self, handler: Service, service_call: ServiceCall
) -> None:
"""Execute a service."""
if handler.is_coroutinefunction:
await handler.func(service_call)
elif handler.is_callback:
handler.func(service_call)
if handler.job.job_type == HassJobType.Coroutinefunction:
await handler.job.target(service_call)
elif handler.job.job_type == HassJobType.Callback:
handler.job.target(service_call)
else:
await self._hass.async_add_executor_job(handler.func, service_call)
await self._hass.async_add_executor_job(handler.job.target, service_call)
class Config:

View File

@ -2,7 +2,7 @@
import logging
from typing import Any, Callable
from homeassistant.core import callback
from homeassistant.core import HassJob, callback
from homeassistant.loader import bind_hass
from homeassistant.util.async_ import run_callback_threadsafe
from homeassistant.util.logging import catch_log_exception
@ -41,26 +41,25 @@ def async_dispatcher_connect(
if DATA_DISPATCHER not in hass.data:
hass.data[DATA_DISPATCHER] = {}
if signal not in hass.data[DATA_DISPATCHER]:
hass.data[DATA_DISPATCHER][signal] = []
wrapped_target = catch_log_exception(
target,
lambda *args: "Exception in {} when dispatching '{}': {}".format(
# Functions wrapped in partial do not have a __name__
getattr(target, "__name__", None) or str(target),
signal,
args,
),
job = HassJob(
catch_log_exception(
target,
lambda *args: "Exception in {} when dispatching '{}': {}".format(
# Functions wrapped in partial do not have a __name__
getattr(target, "__name__", None) or str(target),
signal,
args,
),
)
)
hass.data[DATA_DISPATCHER][signal].append(wrapped_target)
hass.data[DATA_DISPATCHER].setdefault(signal, []).append(job)
@callback
def async_remove_dispatcher() -> None:
"""Remove signal listener."""
try:
hass.data[DATA_DISPATCHER][signal].remove(wrapped_target)
hass.data[DATA_DISPATCHER][signal].remove(job)
except (KeyError, ValueError):
# KeyError is key target listener did not exist
# ValueError if listener did not exist within signal
@ -84,5 +83,5 @@ def async_dispatcher_send(hass: HomeAssistantType, signal: str, *args: Any) -> N
"""
target_list = hass.data.get(DATA_DISPATCHER, {}).get(signal, [])
for target in target_list:
hass.async_add_job(target, *args)
for job in target_list:
hass.async_add_hass_job(job, *args)

View File

@ -34,6 +34,7 @@ from homeassistant.const import (
from homeassistant.core import (
CALLBACK_TYPE,
Event,
HassJob,
HomeAssistant,
State,
callback,
@ -174,6 +175,8 @@ def async_track_state_change(
else:
entity_ids = tuple(entity_id.lower() for entity_id in entity_ids)
job = HassJob(action)
@callback
def state_change_listener(event: Event) -> None:
"""Handle specific state changes."""
@ -192,8 +195,8 @@ def async_track_state_change(
if not match_to_state(new_state):
return
hass.async_run_job(
action,
hass.async_run_hass_job(
job,
event.data.get("entity_id"),
event.data.get("old_state"),
event.data.get("new_state"),
@ -246,9 +249,9 @@ def async_track_state_change_event(
if entity_id not in entity_callbacks:
return
for action in entity_callbacks[entity_id][:]:
for job in entity_callbacks[entity_id][:]:
try:
hass.async_run_job(action, event)
hass.async_run_hass_job(job, event)
except Exception: # pylint: disable=broad-except
_LOGGER.exception(
"Error while processing state changed for %s", entity_id
@ -258,10 +261,12 @@ def async_track_state_change_event(
EVENT_STATE_CHANGED, _async_state_change_dispatcher
)
job = HassJob(action)
entity_ids = _async_string_to_lower_list(entity_ids)
for entity_id in entity_ids:
entity_callbacks.setdefault(entity_id, []).append(action)
entity_callbacks.setdefault(entity_id, []).append(job)
@callback
def remove_listener() -> None:
@ -271,7 +276,7 @@ def async_track_state_change_event(
TRACK_STATE_CHANGE_CALLBACKS,
TRACK_STATE_CHANGE_LISTENER,
entity_ids,
action,
job,
)
return remove_listener
@ -283,14 +288,14 @@ def _async_remove_indexed_listeners(
data_key: str,
listener_key: str,
storage_keys: Iterable[str],
action: Callable[[Event], Any],
job: HassJob,
) -> None:
"""Remove a listener."""
callbacks = hass.data[data_key]
for storage_key in storage_keys:
callbacks[storage_key].remove(action)
callbacks[storage_key].remove(job)
if len(callbacks[storage_key]) == 0:
del callbacks[storage_key]
@ -322,9 +327,9 @@ def async_track_entity_registry_updated_event(
if entity_id not in entity_callbacks:
return
for action in entity_callbacks[entity_id][:]:
for job in entity_callbacks[entity_id][:]:
try:
hass.async_run_job(action, event)
hass.async_run_hass_job(job, event)
except Exception: # pylint: disable=broad-except
_LOGGER.exception(
"Error while processing entity registry update for %s",
@ -335,10 +340,12 @@ def async_track_entity_registry_updated_event(
EVENT_ENTITY_REGISTRY_UPDATED, _async_entity_registry_updated_dispatcher
)
job = HassJob(action)
entity_ids = _async_string_to_lower_list(entity_ids)
for entity_id in entity_ids:
entity_callbacks.setdefault(entity_id, []).append(action)
entity_callbacks.setdefault(entity_id, []).append(job)
@callback
def remove_listener() -> None:
@ -348,7 +355,7 @@ def async_track_entity_registry_updated_event(
TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS,
TRACK_ENTITY_REGISTRY_UPDATED_LISTENER,
entity_ids,
action,
job,
)
return remove_listener
@ -365,9 +372,9 @@ def _async_dispatch_domain_event(
listeners = callbacks.get(domain, []) + callbacks.get(MATCH_ALL, [])
for action in listeners:
for job in listeners:
try:
hass.async_run_job(action, event)
hass.async_run_hass_job(job, event)
except Exception: # pylint: disable=broad-except
_LOGGER.exception(
"Error while processing event %s for domain %s", event, domain
@ -398,10 +405,12 @@ def async_track_state_added_domain(
EVENT_STATE_CHANGED, _async_state_change_dispatcher
)
job = HassJob(action)
domains = _async_string_to_lower_list(domains)
for domain in domains:
domain_callbacks.setdefault(domain, []).append(action)
domain_callbacks.setdefault(domain, []).append(job)
@callback
def remove_listener() -> None:
@ -411,7 +420,7 @@ def async_track_state_added_domain(
TRACK_STATE_ADDED_DOMAIN_CALLBACKS,
TRACK_STATE_ADDED_DOMAIN_LISTENER,
domains,
action,
job,
)
return remove_listener
@ -441,10 +450,12 @@ def async_track_state_removed_domain(
EVENT_STATE_CHANGED, _async_state_change_dispatcher
)
job = HassJob(action)
domains = _async_string_to_lower_list(domains)
for domain in domains:
domain_callbacks.setdefault(domain, []).append(action)
domain_callbacks.setdefault(domain, []).append(job)
@callback
def remove_listener() -> None:
@ -454,7 +465,7 @@ def async_track_state_removed_domain(
TRACK_STATE_REMOVED_DOMAIN_CALLBACKS,
TRACK_STATE_REMOVED_DOMAIN_LISTENER,
domains,
action,
job,
)
return remove_listener
@ -665,6 +676,8 @@ def async_track_template(
"""
job = HassJob(action)
@callback
def _template_changed_listener(
event: Event, updates: List[TrackTemplateResult]
@ -691,8 +704,8 @@ def async_track_template(
):
return
hass.async_run_job(
action,
hass.async_run_hass_job(
job,
event.data.get("entity_id"),
event.data.get("old_state"),
event.data.get("new_state"),
@ -719,7 +732,7 @@ class _TrackTemplateResultInfo:
):
"""Handle removal / refresh of tracker init."""
self.hass = hass
self._action = action
self._job = HassJob(action)
for track_template_ in track_templates:
track_template_.template.hass = hass
@ -866,7 +879,7 @@ class _TrackTemplateResultInfo:
for track_result in updates:
self._last_result[track_result.template] = track_result.result
self.hass.async_run_job(self._action, event, updates)
self.hass.async_run_hass_job(self._job, event, updates)
TrackTemplateResultListener = Callable[
@ -951,6 +964,8 @@ def async_track_same_state(
async_remove_state_for_cancel: Optional[CALLBACK_TYPE] = None
async_remove_state_for_listener: Optional[CALLBACK_TYPE] = None
job = HassJob(action)
@callback
def clear_listener() -> None:
"""Clear all unsub listener."""
@ -969,7 +984,7 @@ def async_track_same_state(
nonlocal async_remove_state_for_listener
async_remove_state_for_listener = None
clear_listener()
hass.async_run_job(action)
hass.async_run_hass_job(job)
@callback
def state_for_cancel_listener(event: Event) -> None:
@ -1005,14 +1020,18 @@ track_same_state = threaded_listener_factory(async_track_same_state)
@callback
@bind_hass
def async_track_point_in_time(
hass: HomeAssistant, action: Callable[..., None], point_in_time: datetime
hass: HomeAssistant,
action: Union[HassJob, Callable[..., None]],
point_in_time: datetime,
) -> CALLBACK_TYPE:
"""Add a listener that fires once after a specific point in time."""
job = action if isinstance(action, HassJob) else HassJob(action)
@callback
def utc_converter(utc_now: datetime) -> None:
"""Convert passed in UTC now to local now."""
hass.async_run_job(action, dt_util.as_local(utc_now))
hass.async_run_hass_job(job, dt_util.as_local(utc_now))
return async_track_point_in_utc_time(hass, utc_converter, point_in_time)
@ -1023,16 +1042,22 @@ track_point_in_time = threaded_listener_factory(async_track_point_in_time)
@callback
@bind_hass
def async_track_point_in_utc_time(
hass: HomeAssistant, action: Callable[..., Any], point_in_time: datetime
hass: HomeAssistant,
action: Union[HassJob, Callable[..., None]],
point_in_time: datetime,
) -> CALLBACK_TYPE:
"""Add a listener that fires once after a specific point in UTC time."""
# Ensure point_in_time is UTC
utc_point_in_time = dt_util.as_utc(point_in_time)
# Since this is called once, we accept a HassJob so we can avoid
# having to figure out how to call the action every time its called.
job = action if isinstance(action, HassJob) else HassJob(action)
cancel_callback = hass.loop.call_at(
hass.loop.time() + point_in_time.timestamp() - time.time(),
hass.async_run_job,
action,
hass.async_run_hass_job,
job,
utc_point_in_time,
)
@ -1050,7 +1075,7 @@ track_point_in_utc_time = threaded_listener_factory(async_track_point_in_utc_tim
@callback
@bind_hass
def async_call_later(
hass: HomeAssistant, delay: float, action: Callable[..., None]
hass: HomeAssistant, delay: float, action: Union[HassJob, Callable[..., None]]
) -> CALLBACK_TYPE:
"""Add a listener that is called in <delay>."""
return async_track_point_in_utc_time(
@ -1071,6 +1096,8 @@ def async_track_time_interval(
"""Add a listener that fires repetitively at every timedelta interval."""
remove = None
job = HassJob(action)
def next_interval() -> datetime:
"""Return the next interval."""
return dt_util.utcnow() + interval
@ -1080,7 +1107,7 @@ def async_track_time_interval(
"""Handle elapsed intervals."""
nonlocal remove
remove = async_track_point_in_utc_time(hass, interval_listener, next_interval())
hass.async_run_job(action, now)
hass.async_run_hass_job(job, now)
remove = async_track_point_in_utc_time(hass, interval_listener, next_interval())
@ -1196,6 +1223,8 @@ def async_track_utc_time_change(
local: bool = False,
) -> CALLBACK_TYPE:
"""Add a listener that will fire if time matches a pattern."""
job = HassJob(action)
# We do not have to wrap the function with time pattern matching logic
# if no pattern given
if all(val is None for val in (hour, minute, second)):
@ -1203,7 +1232,7 @@ def async_track_utc_time_change(
@callback
def time_change_listener(event: Event) -> None:
"""Fire every time event that comes in."""
hass.async_run_job(action, event.data[ATTR_NOW])
hass.async_run_hass_job(job, event.data[ATTR_NOW])
return hass.bus.async_listen(EVENT_TIME_CHANGED, time_change_listener)
@ -1233,7 +1262,7 @@ def async_track_utc_time_change(
nonlocal next_time, cancel_callback
now = pattern_utc_now()
hass.async_run_job(action, dt_util.as_local(now) if local else now)
hass.async_run_hass_job(job, dt_util.as_local(now) if local else now)
calculate_next(now + timedelta(seconds=1))

View File

@ -9,7 +9,7 @@ import urllib.error
import aiohttp
import requests
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
from homeassistant.helpers import entity, event
from homeassistant.util.dt import utcnow
@ -48,6 +48,7 @@ class DataUpdateCoordinator(Generic[T]):
self.data: Optional[T] = None
self._listeners: List[CALLBACK_TYPE] = []
self._job = HassJob(self._handle_refresh_interval)
self._unsub_refresh: Optional[CALLBACK_TYPE] = None
self._request_refresh_task: Optional[asyncio.TimerHandle] = None
self.last_update_success = True
@ -108,7 +109,7 @@ class DataUpdateCoordinator(Generic[T]):
# as long as the update process takes less than a second
self._unsub_refresh = event.async_track_point_in_utc_time(
self.hass,
self._handle_refresh_interval,
self._job,
utcnow().replace(microsecond=0) + self.update_interval,
)

View File

@ -269,7 +269,7 @@ async def test_turn_on_to_not_block_for_domains_without_service(hass):
"homeassistant.core.ServiceRegistry.async_call",
return_value=None,
) as mock_call:
await service.func(service_call)
await service.job.target(service_call)
assert mock_call.call_count == 2
assert mock_call.call_args_list[0][0] == (

View File

@ -415,6 +415,10 @@ def legacy_patchable_time():
# Ensure point_in_time is UTC
point_in_time = event.dt_util.as_utc(point_in_time)
# Since this is called once, we accept a HassJob so we can avoid
# having to figure out how to call the action every time its called.
job = action if isinstance(action, ha.HassJob) else ha.HassJob(action)
@ha.callback
def point_in_time_listener(event):
"""Listen for matching time_changed events."""
@ -431,7 +435,7 @@ def legacy_patchable_time():
setattr(point_in_time_listener, "run", True)
async_unsub()
hass.async_run_job(action, now)
hass.async_run_hass_job(job, now)
async_unsub = hass.bus.async_listen(EVENT_TIME_CHANGED, point_in_time_listener)
@ -443,6 +447,8 @@ def legacy_patchable_time():
hass, action, hour=None, minute=None, second=None, local=False
):
"""Add a listener that will fire if time matches a pattern."""
job = ha.HassJob(action)
# We do not have to wrap the function with time pattern matching logic
# if no pattern given
if all(val is None for val in (hour, minute, second)):
@ -450,7 +456,7 @@ def legacy_patchable_time():
@ha.callback
def time_change_listener(ev) -> None:
"""Fire every time event that comes in."""
hass.async_run_job(action, ev.data[ATTR_NOW])
hass.async_run_hass_job(job, ev.data[ATTR_NOW])
return hass.bus.async_listen(EVENT_TIME_CHANGED, time_change_listener)
@ -487,8 +493,8 @@ def legacy_patchable_time():
last_now = now
if next_time <= now:
hass.async_run_job(
action, event.dt_util.as_local(now) if local else now
hass.async_run_hass_job(
job, event.dt_util.as_local(now) if local else now
)
calculate_next(now + datetime.timedelta(seconds=1))

View File

@ -48,43 +48,43 @@ def test_split_entity_id():
assert ha.split_entity_id("domain.object_id") == ["domain", "object_id"]
def test_async_add_job_schedule_callback():
def test_async_add_hass_job_schedule_callback():
"""Test that we schedule coroutines and add jobs to the job pool."""
hass = MagicMock()
job = MagicMock()
ha.HomeAssistant.async_add_job(hass, ha.callback(job))
ha.HomeAssistant.async_add_hass_job(hass, ha.HassJob(ha.callback(job)))
assert len(hass.loop.call_soon.mock_calls) == 1
assert len(hass.loop.create_task.mock_calls) == 0
assert len(hass.add_job.mock_calls) == 0
def test_async_add_job_schedule_partial_callback():
def test_async_add_hass_job_schedule_partial_callback():
"""Test that we schedule partial coros and add jobs to the job pool."""
hass = MagicMock()
job = MagicMock()
partial = functools.partial(ha.callback(job))
ha.HomeAssistant.async_add_job(hass, partial)
ha.HomeAssistant.async_add_hass_job(hass, ha.HassJob(partial))
assert len(hass.loop.call_soon.mock_calls) == 1
assert len(hass.loop.create_task.mock_calls) == 0
assert len(hass.add_job.mock_calls) == 0
def test_async_add_job_schedule_coroutinefunction(loop):
def test_async_add_hass_job_schedule_coroutinefunction(loop):
"""Test that we schedule coroutines and add jobs to the job pool."""
hass = MagicMock(loop=MagicMock(wraps=loop))
async def job():
pass
ha.HomeAssistant.async_add_job(hass, job)
ha.HomeAssistant.async_add_hass_job(hass, ha.HassJob(job))
assert len(hass.loop.call_soon.mock_calls) == 0
assert len(hass.loop.create_task.mock_calls) == 1
assert len(hass.add_job.mock_calls) == 0
def test_async_add_job_schedule_partial_coroutinefunction(loop):
def test_async_add_hass_job_schedule_partial_coroutinefunction(loop):
"""Test that we schedule partial coros and add jobs to the job pool."""
hass = MagicMock(loop=MagicMock(wraps=loop))
@ -93,20 +93,20 @@ def test_async_add_job_schedule_partial_coroutinefunction(loop):
partial = functools.partial(job)
ha.HomeAssistant.async_add_job(hass, partial)
ha.HomeAssistant.async_add_hass_job(hass, ha.HassJob(partial))
assert len(hass.loop.call_soon.mock_calls) == 0
assert len(hass.loop.create_task.mock_calls) == 1
assert len(hass.add_job.mock_calls) == 0
def test_async_add_job_add_threaded_job_to_pool():
def test_async_add_job_add_hass_threaded_job_to_pool():
"""Test that we schedule coroutines and add jobs to the job pool."""
hass = MagicMock()
def job():
pass
ha.HomeAssistant.async_add_job(hass, job)
ha.HomeAssistant.async_add_hass_job(hass, ha.HassJob(job))
assert len(hass.loop.call_soon.mock_calls) == 0
assert len(hass.loop.create_task.mock_calls) == 0
assert len(hass.loop.run_in_executor.mock_calls) == 1
@ -125,7 +125,7 @@ def test_async_create_task_schedule_coroutine(loop):
assert len(hass.add_job.mock_calls) == 0
def test_async_run_job_calls_callback():
def test_async_run_hass_job_calls_callback():
"""Test that the callback annotation is respected."""
hass = MagicMock()
calls = []
@ -133,12 +133,12 @@ def test_async_run_job_calls_callback():
def job():
calls.append(1)
ha.HomeAssistant.async_run_job(hass, ha.callback(job))
ha.HomeAssistant.async_run_hass_job(hass, ha.HassJob(ha.callback(job)))
assert len(calls) == 1
assert len(hass.async_add_job.mock_calls) == 0
def test_async_run_job_delegates_non_async():
def test_async_run_hass_job_delegates_non_async():
"""Test that the callback annotation is respected."""
hass = MagicMock()
calls = []
@ -146,9 +146,9 @@ def test_async_run_job_delegates_non_async():
def job():
calls.append(1)
ha.HomeAssistant.async_run_job(hass, job)
ha.HomeAssistant.async_run_hass_job(hass, ha.HassJob(job))
assert len(calls) == 0
assert len(hass.async_add_job.mock_calls) == 1
assert len(hass.async_add_hass_job.mock_calls) == 1
def test_stage_shutdown():