diff --git a/homeassistant/core.py b/homeassistant/core.py index 436580a78580..fa0e294e52e9 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -149,6 +149,52 @@ def is_callback(func: Callable[..., Any]) -> bool: return getattr(func, "_hass_callback", False) is True +@enum.unique +class HassJobType(enum.Enum): + """Represent a job type.""" + + Coroutine = 1 + Coroutinefunction = 2 + Callback = 3 + Executor = 4 + + +class HassJob: + """Represent a job to be run later. + + We check the callable type in advance + so we can avoid checking it every time + we run the job. + """ + + __slots__ = ("job_type", "target") + + def __init__(self, target: Callable): + """Create a job object.""" + self.target = target + self.job_type = _get_callable_job_type(target) + + def __repr__(self) -> str: + """Return the job.""" + return f"" + + +def _get_callable_job_type(target: Callable) -> HassJobType: + """Determine the job type from the callable.""" + # Check for partials to properly determine if coroutine function + check_target = target + while isinstance(check_target, functools.partial): + check_target = check_target.func + + if asyncio.iscoroutine(check_target): + return HassJobType.Coroutine + if asyncio.iscoroutinefunction(check_target): + return HassJobType.Coroutinefunction + if is_callback(check_target): + return HassJobType.Callback + return HassJobType.Executor + + class CoreState(enum.Enum): """Represent the current state of Home Assistant.""" @@ -306,24 +352,32 @@ class HomeAssistant: if target is None: raise ValueError("Don't call async_add_job with None") - task = None + return self.async_add_hass_job(HassJob(target), *args) - # Check for partials to properly determine if coroutine function - check_target = target - while isinstance(check_target, functools.partial): - check_target = check_target.func + @callback + def async_add_hass_job( + self, hassjob: HassJob, *args: Any + ) -> Optional[asyncio.Future]: + """Add a HassJob from within the event loop. - if asyncio.iscoroutine(check_target): - task = self.loop.create_task(target) # type: ignore - elif asyncio.iscoroutinefunction(check_target): - task = self.loop.create_task(target(*args)) - elif is_callback(check_target): - self.loop.call_soon(target, *args) + This method must be run in the event loop. + hassjob: HassJob to call. + args: parameters for method to call. + """ + if hassjob.job_type == HassJobType.Coroutine: + task = self.loop.create_task(hassjob.target) # type: ignore + elif hassjob.job_type == HassJobType.Coroutinefunction: + task = self.loop.create_task(hassjob.target(*args)) + elif hassjob.job_type == HassJobType.Callback: + self.loop.call_soon(hassjob.target, *args) + return None else: - task = self.loop.run_in_executor(None, target, *args) # type: ignore + task = self.loop.run_in_executor( # type: ignore + None, hassjob.target, *args + ) # If a task is scheduled - if self._track_task and task is not None: + if self._track_task: self._pending_tasks.append(task) return task @@ -366,6 +420,20 @@ class HomeAssistant: """Stop track tasks so you can't wait for all tasks to be done.""" self._track_task = False + @callback + def async_run_hass_job(self, hassjob: HassJob, *args: Any) -> None: + """Run a HassJob from within the event loop. + + This method must be run in the event loop. + + hassjob: HassJob + args: parameters for method to call. + """ + if hassjob.job_type == HassJobType.Callback: + hassjob.target(*args) + else: + self.async_add_hass_job(hassjob, *args) + @callback def async_run_job( self, target: Callable[..., Union[None, Awaitable]], *args: Any @@ -377,14 +445,7 @@ class HomeAssistant: target: target to call. args: parameters for method to call. """ - if ( - not asyncio.iscoroutine(target) - and not asyncio.iscoroutinefunction(target) - and is_callback(target) - ): - target(*args) - else: - self.async_add_job(target, *args) + self.async_run_hass_job(HassJob(target), *args) def block_till_done(self) -> None: """Block until all pending work is done.""" @@ -592,7 +653,7 @@ class EventBus: def __init__(self, hass: HomeAssistant) -> None: """Initialize a new event bus.""" - self._listeners: Dict[str, List[Callable]] = {} + self._listeners: Dict[str, List[HassJob]] = {} self._hass = hass @callback @@ -648,8 +709,8 @@ class EventBus: if not listeners: return - for func in listeners: - self._hass.async_add_job(func, event) + for job in listeners: + self._hass.async_add_hass_job(job, event) def listen(self, event_type: str, listener: Callable) -> CALLBACK_TYPE: """Listen for all events or events of a specific type. @@ -676,14 +737,15 @@ class EventBus: This method must be run in the event loop. """ - if event_type in self._listeners: - self._listeners[event_type].append(listener) - else: - self._listeners[event_type] = [listener] + return self._async_listen_job(event_type, HassJob(listener)) + + @callback + def _async_listen_job(self, event_type: str, hassjob: HassJob) -> CALLBACK_TYPE: + self._listeners.setdefault(event_type, []).append(hassjob) def remove_listener() -> None: """Remove the listener.""" - self._async_remove_listener(event_type, listener) + self._async_remove_listener(event_type, hassjob) return remove_listener @@ -716,31 +778,36 @@ class EventBus: This method must be run in the event loop. """ + job: Optional[HassJob] = None @callback - def onetime_listener(event: Event) -> None: + def _onetime_listener(event: Event) -> None: """Remove listener from event bus and then fire listener.""" - if hasattr(onetime_listener, "run"): + nonlocal job + if hasattr(_onetime_listener, "run"): return # Set variable so that we will never run twice. # Because the event bus loop might have async_fire queued multiple # times, its possible this listener may already be lined up # multiple times as well. # This will make sure the second time it does nothing. - setattr(onetime_listener, "run", True) - self._async_remove_listener(event_type, onetime_listener) + setattr(_onetime_listener, "run", True) + assert job is not None + self._async_remove_listener(event_type, job) self._hass.async_run_job(listener, event) - return self.async_listen(event_type, onetime_listener) + job = HassJob(_onetime_listener) + + return self._async_listen_job(event_type, job) @callback - def _async_remove_listener(self, event_type: str, listener: Callable) -> None: + def _async_remove_listener(self, event_type: str, hassjob: HassJob) -> None: """Remove a listener of a specific event_type. This method must be run in the event loop. """ try: - self._listeners[event_type].remove(listener) + self._listeners[event_type].remove(hassjob) # delete event_type list if empty if not self._listeners[event_type]: @@ -748,7 +815,7 @@ class EventBus: except (KeyError, ValueError): # KeyError is key event_type listener did not exist # ValueError if listener did not exist within event_type - _LOGGER.warning("Unable to remove unknown listener %s", listener) + _LOGGER.warning("Unable to remove unknown job listener %s", hassjob) class State: @@ -1094,7 +1161,7 @@ class StateMachine: class Service: """Representation of a callable service.""" - __slots__ = ["func", "schema", "is_callback", "is_coroutinefunction"] + __slots__ = ["job", "schema"] def __init__( self, @@ -1103,13 +1170,8 @@ class Service: context: Optional[Context] = None, ) -> None: """Initialize a service.""" - self.func = func + self.job = HassJob(func) self.schema = schema - # Properly detect wrapped functions - while isinstance(func, functools.partial): - func = func.func - self.is_callback = is_callback(func) - self.is_coroutinefunction = asyncio.iscoroutinefunction(func) class ServiceCall: @@ -1377,12 +1439,12 @@ class ServiceRegistry: self, handler: Service, service_call: ServiceCall ) -> None: """Execute a service.""" - if handler.is_coroutinefunction: - await handler.func(service_call) - elif handler.is_callback: - handler.func(service_call) + if handler.job.job_type == HassJobType.Coroutinefunction: + await handler.job.target(service_call) + elif handler.job.job_type == HassJobType.Callback: + handler.job.target(service_call) else: - await self._hass.async_add_executor_job(handler.func, service_call) + await self._hass.async_add_executor_job(handler.job.target, service_call) class Config: diff --git a/homeassistant/helpers/dispatcher.py b/homeassistant/helpers/dispatcher.py index bb6fa3a735db..cdf24ec23e98 100644 --- a/homeassistant/helpers/dispatcher.py +++ b/homeassistant/helpers/dispatcher.py @@ -2,7 +2,7 @@ import logging from typing import Any, Callable -from homeassistant.core import callback +from homeassistant.core import HassJob, callback from homeassistant.loader import bind_hass from homeassistant.util.async_ import run_callback_threadsafe from homeassistant.util.logging import catch_log_exception @@ -41,26 +41,25 @@ def async_dispatcher_connect( if DATA_DISPATCHER not in hass.data: hass.data[DATA_DISPATCHER] = {} - if signal not in hass.data[DATA_DISPATCHER]: - hass.data[DATA_DISPATCHER][signal] = [] - - wrapped_target = catch_log_exception( - target, - lambda *args: "Exception in {} when dispatching '{}': {}".format( - # Functions wrapped in partial do not have a __name__ - getattr(target, "__name__", None) or str(target), - signal, - args, - ), + job = HassJob( + catch_log_exception( + target, + lambda *args: "Exception in {} when dispatching '{}': {}".format( + # Functions wrapped in partial do not have a __name__ + getattr(target, "__name__", None) or str(target), + signal, + args, + ), + ) ) - hass.data[DATA_DISPATCHER][signal].append(wrapped_target) + hass.data[DATA_DISPATCHER].setdefault(signal, []).append(job) @callback def async_remove_dispatcher() -> None: """Remove signal listener.""" try: - hass.data[DATA_DISPATCHER][signal].remove(wrapped_target) + hass.data[DATA_DISPATCHER][signal].remove(job) except (KeyError, ValueError): # KeyError is key target listener did not exist # ValueError if listener did not exist within signal @@ -84,5 +83,5 @@ def async_dispatcher_send(hass: HomeAssistantType, signal: str, *args: Any) -> N """ target_list = hass.data.get(DATA_DISPATCHER, {}).get(signal, []) - for target in target_list: - hass.async_add_job(target, *args) + for job in target_list: + hass.async_add_hass_job(job, *args) diff --git a/homeassistant/helpers/event.py b/homeassistant/helpers/event.py index c73196db6044..2bc6f77664d3 100644 --- a/homeassistant/helpers/event.py +++ b/homeassistant/helpers/event.py @@ -34,6 +34,7 @@ from homeassistant.const import ( from homeassistant.core import ( CALLBACK_TYPE, Event, + HassJob, HomeAssistant, State, callback, @@ -174,6 +175,8 @@ def async_track_state_change( else: entity_ids = tuple(entity_id.lower() for entity_id in entity_ids) + job = HassJob(action) + @callback def state_change_listener(event: Event) -> None: """Handle specific state changes.""" @@ -192,8 +195,8 @@ def async_track_state_change( if not match_to_state(new_state): return - hass.async_run_job( - action, + hass.async_run_hass_job( + job, event.data.get("entity_id"), event.data.get("old_state"), event.data.get("new_state"), @@ -246,9 +249,9 @@ def async_track_state_change_event( if entity_id not in entity_callbacks: return - for action in entity_callbacks[entity_id][:]: + for job in entity_callbacks[entity_id][:]: try: - hass.async_run_job(action, event) + hass.async_run_hass_job(job, event) except Exception: # pylint: disable=broad-except _LOGGER.exception( "Error while processing state changed for %s", entity_id @@ -258,10 +261,12 @@ def async_track_state_change_event( EVENT_STATE_CHANGED, _async_state_change_dispatcher ) + job = HassJob(action) + entity_ids = _async_string_to_lower_list(entity_ids) for entity_id in entity_ids: - entity_callbacks.setdefault(entity_id, []).append(action) + entity_callbacks.setdefault(entity_id, []).append(job) @callback def remove_listener() -> None: @@ -271,7 +276,7 @@ def async_track_state_change_event( TRACK_STATE_CHANGE_CALLBACKS, TRACK_STATE_CHANGE_LISTENER, entity_ids, - action, + job, ) return remove_listener @@ -283,14 +288,14 @@ def _async_remove_indexed_listeners( data_key: str, listener_key: str, storage_keys: Iterable[str], - action: Callable[[Event], Any], + job: HassJob, ) -> None: """Remove a listener.""" callbacks = hass.data[data_key] for storage_key in storage_keys: - callbacks[storage_key].remove(action) + callbacks[storage_key].remove(job) if len(callbacks[storage_key]) == 0: del callbacks[storage_key] @@ -322,9 +327,9 @@ def async_track_entity_registry_updated_event( if entity_id not in entity_callbacks: return - for action in entity_callbacks[entity_id][:]: + for job in entity_callbacks[entity_id][:]: try: - hass.async_run_job(action, event) + hass.async_run_hass_job(job, event) except Exception: # pylint: disable=broad-except _LOGGER.exception( "Error while processing entity registry update for %s", @@ -335,10 +340,12 @@ def async_track_entity_registry_updated_event( EVENT_ENTITY_REGISTRY_UPDATED, _async_entity_registry_updated_dispatcher ) + job = HassJob(action) + entity_ids = _async_string_to_lower_list(entity_ids) for entity_id in entity_ids: - entity_callbacks.setdefault(entity_id, []).append(action) + entity_callbacks.setdefault(entity_id, []).append(job) @callback def remove_listener() -> None: @@ -348,7 +355,7 @@ def async_track_entity_registry_updated_event( TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS, TRACK_ENTITY_REGISTRY_UPDATED_LISTENER, entity_ids, - action, + job, ) return remove_listener @@ -365,9 +372,9 @@ def _async_dispatch_domain_event( listeners = callbacks.get(domain, []) + callbacks.get(MATCH_ALL, []) - for action in listeners: + for job in listeners: try: - hass.async_run_job(action, event) + hass.async_run_hass_job(job, event) except Exception: # pylint: disable=broad-except _LOGGER.exception( "Error while processing event %s for domain %s", event, domain @@ -398,10 +405,12 @@ def async_track_state_added_domain( EVENT_STATE_CHANGED, _async_state_change_dispatcher ) + job = HassJob(action) + domains = _async_string_to_lower_list(domains) for domain in domains: - domain_callbacks.setdefault(domain, []).append(action) + domain_callbacks.setdefault(domain, []).append(job) @callback def remove_listener() -> None: @@ -411,7 +420,7 @@ def async_track_state_added_domain( TRACK_STATE_ADDED_DOMAIN_CALLBACKS, TRACK_STATE_ADDED_DOMAIN_LISTENER, domains, - action, + job, ) return remove_listener @@ -441,10 +450,12 @@ def async_track_state_removed_domain( EVENT_STATE_CHANGED, _async_state_change_dispatcher ) + job = HassJob(action) + domains = _async_string_to_lower_list(domains) for domain in domains: - domain_callbacks.setdefault(domain, []).append(action) + domain_callbacks.setdefault(domain, []).append(job) @callback def remove_listener() -> None: @@ -454,7 +465,7 @@ def async_track_state_removed_domain( TRACK_STATE_REMOVED_DOMAIN_CALLBACKS, TRACK_STATE_REMOVED_DOMAIN_LISTENER, domains, - action, + job, ) return remove_listener @@ -665,6 +676,8 @@ def async_track_template( """ + job = HassJob(action) + @callback def _template_changed_listener( event: Event, updates: List[TrackTemplateResult] @@ -691,8 +704,8 @@ def async_track_template( ): return - hass.async_run_job( - action, + hass.async_run_hass_job( + job, event.data.get("entity_id"), event.data.get("old_state"), event.data.get("new_state"), @@ -719,7 +732,7 @@ class _TrackTemplateResultInfo: ): """Handle removal / refresh of tracker init.""" self.hass = hass - self._action = action + self._job = HassJob(action) for track_template_ in track_templates: track_template_.template.hass = hass @@ -866,7 +879,7 @@ class _TrackTemplateResultInfo: for track_result in updates: self._last_result[track_result.template] = track_result.result - self.hass.async_run_job(self._action, event, updates) + self.hass.async_run_hass_job(self._job, event, updates) TrackTemplateResultListener = Callable[ @@ -951,6 +964,8 @@ def async_track_same_state( async_remove_state_for_cancel: Optional[CALLBACK_TYPE] = None async_remove_state_for_listener: Optional[CALLBACK_TYPE] = None + job = HassJob(action) + @callback def clear_listener() -> None: """Clear all unsub listener.""" @@ -969,7 +984,7 @@ def async_track_same_state( nonlocal async_remove_state_for_listener async_remove_state_for_listener = None clear_listener() - hass.async_run_job(action) + hass.async_run_hass_job(job) @callback def state_for_cancel_listener(event: Event) -> None: @@ -1005,14 +1020,18 @@ track_same_state = threaded_listener_factory(async_track_same_state) @callback @bind_hass def async_track_point_in_time( - hass: HomeAssistant, action: Callable[..., None], point_in_time: datetime + hass: HomeAssistant, + action: Union[HassJob, Callable[..., None]], + point_in_time: datetime, ) -> CALLBACK_TYPE: """Add a listener that fires once after a specific point in time.""" + job = action if isinstance(action, HassJob) else HassJob(action) + @callback def utc_converter(utc_now: datetime) -> None: """Convert passed in UTC now to local now.""" - hass.async_run_job(action, dt_util.as_local(utc_now)) + hass.async_run_hass_job(job, dt_util.as_local(utc_now)) return async_track_point_in_utc_time(hass, utc_converter, point_in_time) @@ -1023,16 +1042,22 @@ track_point_in_time = threaded_listener_factory(async_track_point_in_time) @callback @bind_hass def async_track_point_in_utc_time( - hass: HomeAssistant, action: Callable[..., Any], point_in_time: datetime + hass: HomeAssistant, + action: Union[HassJob, Callable[..., None]], + point_in_time: datetime, ) -> CALLBACK_TYPE: """Add a listener that fires once after a specific point in UTC time.""" # Ensure point_in_time is UTC utc_point_in_time = dt_util.as_utc(point_in_time) + # Since this is called once, we accept a HassJob so we can avoid + # having to figure out how to call the action every time its called. + job = action if isinstance(action, HassJob) else HassJob(action) + cancel_callback = hass.loop.call_at( hass.loop.time() + point_in_time.timestamp() - time.time(), - hass.async_run_job, - action, + hass.async_run_hass_job, + job, utc_point_in_time, ) @@ -1050,7 +1075,7 @@ track_point_in_utc_time = threaded_listener_factory(async_track_point_in_utc_tim @callback @bind_hass def async_call_later( - hass: HomeAssistant, delay: float, action: Callable[..., None] + hass: HomeAssistant, delay: float, action: Union[HassJob, Callable[..., None]] ) -> CALLBACK_TYPE: """Add a listener that is called in .""" return async_track_point_in_utc_time( @@ -1071,6 +1096,8 @@ def async_track_time_interval( """Add a listener that fires repetitively at every timedelta interval.""" remove = None + job = HassJob(action) + def next_interval() -> datetime: """Return the next interval.""" return dt_util.utcnow() + interval @@ -1080,7 +1107,7 @@ def async_track_time_interval( """Handle elapsed intervals.""" nonlocal remove remove = async_track_point_in_utc_time(hass, interval_listener, next_interval()) - hass.async_run_job(action, now) + hass.async_run_hass_job(job, now) remove = async_track_point_in_utc_time(hass, interval_listener, next_interval()) @@ -1196,6 +1223,8 @@ def async_track_utc_time_change( local: bool = False, ) -> CALLBACK_TYPE: """Add a listener that will fire if time matches a pattern.""" + + job = HassJob(action) # We do not have to wrap the function with time pattern matching logic # if no pattern given if all(val is None for val in (hour, minute, second)): @@ -1203,7 +1232,7 @@ def async_track_utc_time_change( @callback def time_change_listener(event: Event) -> None: """Fire every time event that comes in.""" - hass.async_run_job(action, event.data[ATTR_NOW]) + hass.async_run_hass_job(job, event.data[ATTR_NOW]) return hass.bus.async_listen(EVENT_TIME_CHANGED, time_change_listener) @@ -1233,7 +1262,7 @@ def async_track_utc_time_change( nonlocal next_time, cancel_callback now = pattern_utc_now() - hass.async_run_job(action, dt_util.as_local(now) if local else now) + hass.async_run_hass_job(job, dt_util.as_local(now) if local else now) calculate_next(now + timedelta(seconds=1)) diff --git a/homeassistant/helpers/update_coordinator.py b/homeassistant/helpers/update_coordinator.py index a43c15dbfbd2..2abe7b8c6b75 100644 --- a/homeassistant/helpers/update_coordinator.py +++ b/homeassistant/helpers/update_coordinator.py @@ -9,7 +9,7 @@ import urllib.error import aiohttp import requests -from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback +from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback from homeassistant.helpers import entity, event from homeassistant.util.dt import utcnow @@ -48,6 +48,7 @@ class DataUpdateCoordinator(Generic[T]): self.data: Optional[T] = None self._listeners: List[CALLBACK_TYPE] = [] + self._job = HassJob(self._handle_refresh_interval) self._unsub_refresh: Optional[CALLBACK_TYPE] = None self._request_refresh_task: Optional[asyncio.TimerHandle] = None self.last_update_success = True @@ -108,7 +109,7 @@ class DataUpdateCoordinator(Generic[T]): # as long as the update process takes less than a second self._unsub_refresh = event.async_track_point_in_utc_time( self.hass, - self._handle_refresh_interval, + self._job, utcnow().replace(microsecond=0) + self.update_interval, ) diff --git a/tests/components/homeassistant/test_init.py b/tests/components/homeassistant/test_init.py index b2ca49712d1f..3ad3ef76483a 100644 --- a/tests/components/homeassistant/test_init.py +++ b/tests/components/homeassistant/test_init.py @@ -269,7 +269,7 @@ async def test_turn_on_to_not_block_for_domains_without_service(hass): "homeassistant.core.ServiceRegistry.async_call", return_value=None, ) as mock_call: - await service.func(service_call) + await service.job.target(service_call) assert mock_call.call_count == 2 assert mock_call.call_args_list[0][0] == ( diff --git a/tests/conftest.py b/tests/conftest.py index 25572a2269be..64bcb8dc9517 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -415,6 +415,10 @@ def legacy_patchable_time(): # Ensure point_in_time is UTC point_in_time = event.dt_util.as_utc(point_in_time) + # Since this is called once, we accept a HassJob so we can avoid + # having to figure out how to call the action every time its called. + job = action if isinstance(action, ha.HassJob) else ha.HassJob(action) + @ha.callback def point_in_time_listener(event): """Listen for matching time_changed events.""" @@ -431,7 +435,7 @@ def legacy_patchable_time(): setattr(point_in_time_listener, "run", True) async_unsub() - hass.async_run_job(action, now) + hass.async_run_hass_job(job, now) async_unsub = hass.bus.async_listen(EVENT_TIME_CHANGED, point_in_time_listener) @@ -443,6 +447,8 @@ def legacy_patchable_time(): hass, action, hour=None, minute=None, second=None, local=False ): """Add a listener that will fire if time matches a pattern.""" + + job = ha.HassJob(action) # We do not have to wrap the function with time pattern matching logic # if no pattern given if all(val is None for val in (hour, minute, second)): @@ -450,7 +456,7 @@ def legacy_patchable_time(): @ha.callback def time_change_listener(ev) -> None: """Fire every time event that comes in.""" - hass.async_run_job(action, ev.data[ATTR_NOW]) + hass.async_run_hass_job(job, ev.data[ATTR_NOW]) return hass.bus.async_listen(EVENT_TIME_CHANGED, time_change_listener) @@ -487,8 +493,8 @@ def legacy_patchable_time(): last_now = now if next_time <= now: - hass.async_run_job( - action, event.dt_util.as_local(now) if local else now + hass.async_run_hass_job( + job, event.dt_util.as_local(now) if local else now ) calculate_next(now + datetime.timedelta(seconds=1)) diff --git a/tests/test_core.py b/tests/test_core.py index 1ae1f32a10a8..402b43b7d11e 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -48,43 +48,43 @@ def test_split_entity_id(): assert ha.split_entity_id("domain.object_id") == ["domain", "object_id"] -def test_async_add_job_schedule_callback(): +def test_async_add_hass_job_schedule_callback(): """Test that we schedule coroutines and add jobs to the job pool.""" hass = MagicMock() job = MagicMock() - ha.HomeAssistant.async_add_job(hass, ha.callback(job)) + ha.HomeAssistant.async_add_hass_job(hass, ha.HassJob(ha.callback(job))) assert len(hass.loop.call_soon.mock_calls) == 1 assert len(hass.loop.create_task.mock_calls) == 0 assert len(hass.add_job.mock_calls) == 0 -def test_async_add_job_schedule_partial_callback(): +def test_async_add_hass_job_schedule_partial_callback(): """Test that we schedule partial coros and add jobs to the job pool.""" hass = MagicMock() job = MagicMock() partial = functools.partial(ha.callback(job)) - ha.HomeAssistant.async_add_job(hass, partial) + ha.HomeAssistant.async_add_hass_job(hass, ha.HassJob(partial)) assert len(hass.loop.call_soon.mock_calls) == 1 assert len(hass.loop.create_task.mock_calls) == 0 assert len(hass.add_job.mock_calls) == 0 -def test_async_add_job_schedule_coroutinefunction(loop): +def test_async_add_hass_job_schedule_coroutinefunction(loop): """Test that we schedule coroutines and add jobs to the job pool.""" hass = MagicMock(loop=MagicMock(wraps=loop)) async def job(): pass - ha.HomeAssistant.async_add_job(hass, job) + ha.HomeAssistant.async_add_hass_job(hass, ha.HassJob(job)) assert len(hass.loop.call_soon.mock_calls) == 0 assert len(hass.loop.create_task.mock_calls) == 1 assert len(hass.add_job.mock_calls) == 0 -def test_async_add_job_schedule_partial_coroutinefunction(loop): +def test_async_add_hass_job_schedule_partial_coroutinefunction(loop): """Test that we schedule partial coros and add jobs to the job pool.""" hass = MagicMock(loop=MagicMock(wraps=loop)) @@ -93,20 +93,20 @@ def test_async_add_job_schedule_partial_coroutinefunction(loop): partial = functools.partial(job) - ha.HomeAssistant.async_add_job(hass, partial) + ha.HomeAssistant.async_add_hass_job(hass, ha.HassJob(partial)) assert len(hass.loop.call_soon.mock_calls) == 0 assert len(hass.loop.create_task.mock_calls) == 1 assert len(hass.add_job.mock_calls) == 0 -def test_async_add_job_add_threaded_job_to_pool(): +def test_async_add_job_add_hass_threaded_job_to_pool(): """Test that we schedule coroutines and add jobs to the job pool.""" hass = MagicMock() def job(): pass - ha.HomeAssistant.async_add_job(hass, job) + ha.HomeAssistant.async_add_hass_job(hass, ha.HassJob(job)) assert len(hass.loop.call_soon.mock_calls) == 0 assert len(hass.loop.create_task.mock_calls) == 0 assert len(hass.loop.run_in_executor.mock_calls) == 1 @@ -125,7 +125,7 @@ def test_async_create_task_schedule_coroutine(loop): assert len(hass.add_job.mock_calls) == 0 -def test_async_run_job_calls_callback(): +def test_async_run_hass_job_calls_callback(): """Test that the callback annotation is respected.""" hass = MagicMock() calls = [] @@ -133,12 +133,12 @@ def test_async_run_job_calls_callback(): def job(): calls.append(1) - ha.HomeAssistant.async_run_job(hass, ha.callback(job)) + ha.HomeAssistant.async_run_hass_job(hass, ha.HassJob(ha.callback(job))) assert len(calls) == 1 assert len(hass.async_add_job.mock_calls) == 0 -def test_async_run_job_delegates_non_async(): +def test_async_run_hass_job_delegates_non_async(): """Test that the callback annotation is respected.""" hass = MagicMock() calls = [] @@ -146,9 +146,9 @@ def test_async_run_job_delegates_non_async(): def job(): calls.append(1) - ha.HomeAssistant.async_run_job(hass, job) + ha.HomeAssistant.async_run_hass_job(hass, ha.HassJob(job)) assert len(calls) == 0 - assert len(hass.async_add_job.mock_calls) == 1 + assert len(hass.async_add_hass_job.mock_calls) == 1 def test_stage_shutdown():